mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 10:53:02 +01:00
Accept attached files in the chat API
- weave through all subsequent subcalls to models, where relevant, and save to conversation log
This commit is contained in:
parent
ecc81e06a7
commit
b8ed98530f
1 changed files with 23 additions and 3 deletions
|
@ -45,7 +45,7 @@ from khoj.routers.helpers import (
|
||||||
aget_relevant_output_modes,
|
aget_relevant_output_modes,
|
||||||
construct_automation_created_message,
|
construct_automation_created_message,
|
||||||
create_automation,
|
create_automation,
|
||||||
gather_attached_files,
|
gather_raw_attached_files,
|
||||||
generate_excalidraw_diagram,
|
generate_excalidraw_diagram,
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
|
@ -71,7 +71,12 @@ from khoj.utils.helpers import (
|
||||||
get_device,
|
get_device,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
|
from khoj.utils.rawconfig import (
|
||||||
|
ChatRequestBody,
|
||||||
|
FileFilterRequest,
|
||||||
|
FilesFilterRequest,
|
||||||
|
LocationData,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -566,6 +571,7 @@ async def chat(
|
||||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||||
timezone = body.timezone
|
timezone = body.timezone
|
||||||
raw_images = body.images
|
raw_images = body.images
|
||||||
|
raw_attached_files = body.files
|
||||||
|
|
||||||
async def event_generator(q: str, images: list[str]):
|
async def event_generator(q: str, images: list[str]):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
@ -577,6 +583,7 @@ async def chat(
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
train_of_thought = []
|
train_of_thought = []
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
|
nonlocal raw_attached_files
|
||||||
|
|
||||||
tracer: dict = {
|
tracer: dict = {
|
||||||
"mid": turn_id,
|
"mid": turn_id,
|
||||||
|
@ -596,6 +603,11 @@ async def chat(
|
||||||
if uploaded_image:
|
if uploaded_image:
|
||||||
uploaded_images.append(uploaded_image)
|
uploaded_images.append(uploaded_image)
|
||||||
|
|
||||||
|
attached_files: Dict[str, str] = {}
|
||||||
|
if raw_attached_files:
|
||||||
|
for file in raw_attached_files:
|
||||||
|
attached_files[file.name] = file.content
|
||||||
|
|
||||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive, ttft, train_of_thought
|
nonlocal connection_alive, ttft, train_of_thought
|
||||||
if not connection_alive or await request.is_disconnected():
|
if not connection_alive or await request.is_disconnected():
|
||||||
|
@ -707,7 +719,7 @@ async def chat(
|
||||||
compiled_references: List[Any] = []
|
compiled_references: List[Any] = []
|
||||||
inferred_queries: List[Any] = []
|
inferred_queries: List[Any] = []
|
||||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||||
attached_file_context = await gather_attached_files(user, file_filters)
|
attached_file_context = gather_raw_attached_files(attached_files)
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
|
@ -833,6 +845,7 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
raw_attached_files=raw_attached_files,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -878,6 +891,7 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
raw_attached_files=raw_attached_files,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response):
|
||||||
yield result
|
yield result
|
||||||
|
@ -900,6 +914,7 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
attached_files=attached_file_context,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -1085,6 +1100,8 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
attached_file_context=attached_file_context,
|
||||||
|
raw_attached_files=raw_attached_files,
|
||||||
)
|
)
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
|
@ -1144,6 +1161,8 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
attached_file_context=attached_file_context,
|
||||||
|
raw_attached_files=raw_attached_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
@ -1172,6 +1191,7 @@ async def chat(
|
||||||
tracer,
|
tracer,
|
||||||
train_of_thought,
|
train_of_thought,
|
||||||
attached_file_context,
|
attached_file_context,
|
||||||
|
raw_attached_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
|
|
Loading…
Reference in a new issue