Accept attached files in the chat API

- Weave attached files through all subsequent subcalls to models, where relevant, and save them to the conversation log
This commit is contained in:
sabaimran 2024-11-07 16:01:48 -08:00
parent ecc81e06a7
commit b8ed98530f

View file

@ -45,7 +45,7 @@ from khoj.routers.helpers import (
aget_relevant_output_modes, aget_relevant_output_modes,
construct_automation_created_message, construct_automation_created_message,
create_automation, create_automation,
gather_attached_files, gather_raw_attached_files,
generate_excalidraw_diagram, generate_excalidraw_diagram,
generate_summary_from_files, generate_summary_from_files,
get_conversation_command, get_conversation_command,
@ -71,7 +71,12 @@ from khoj.utils.helpers import (
get_device, get_device,
is_none_or_empty, is_none_or_empty,
) )
from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData from khoj.utils.rawconfig import (
ChatRequestBody,
FileFilterRequest,
FilesFilterRequest,
LocationData,
)
# Initialize Router # Initialize Router
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -566,6 +571,7 @@ async def chat(
country_code = body.country_code or get_country_code_from_timezone(body.timezone) country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone timezone = body.timezone
raw_images = body.images raw_images = body.images
raw_attached_files = body.files
async def event_generator(q: str, images: list[str]): async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter() start_time = time.perf_counter()
@ -577,6 +583,7 @@ async def chat(
q = unquote(q) q = unquote(q)
train_of_thought = [] train_of_thought = []
nonlocal conversation_id nonlocal conversation_id
nonlocal raw_attached_files
tracer: dict = { tracer: dict = {
"mid": turn_id, "mid": turn_id,
@ -596,6 +603,11 @@ async def chat(
if uploaded_image: if uploaded_image:
uploaded_images.append(uploaded_image) uploaded_images.append(uploaded_image)
attached_files: Dict[str, str] = {}
if raw_attached_files:
for file in raw_attached_files:
attached_files[file.name] = file.content
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft, train_of_thought nonlocal connection_alive, ttft, train_of_thought
if not connection_alive or await request.is_disconnected(): if not connection_alive or await request.is_disconnected():
@ -707,7 +719,7 @@ async def chat(
compiled_references: List[Any] = [] compiled_references: List[Any] = []
inferred_queries: List[Any] = [] inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else [] file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = await gather_attached_files(user, file_filters) attached_file_context = gather_raw_attached_files(attached_files)
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources( conversation_commands = await aget_relevant_information_sources(
@ -833,6 +845,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
) )
return return
@ -878,6 +891,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
) )
async for result in send_llm_response(llm_response): async for result in send_llm_response(llm_response):
yield result yield result
@ -900,6 +914,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1085,6 +1100,8 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
attached_file_context=attached_file_context,
raw_attached_files=raw_attached_files,
) )
content_obj = { content_obj = {
"intentType": intent_type, "intentType": intent_type,
@ -1144,6 +1161,8 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer, tracer=tracer,
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
attached_file_context=attached_file_context,
raw_attached_files=raw_attached_files,
) )
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
@ -1172,6 +1191,7 @@ async def chat(
tracer, tracer,
train_of_thought, train_of_thought,
attached_file_context, attached_file_context,
raw_attached_files,
) )
# Send Response # Send Response