diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index ea63d81f..e2d461f6 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -9,7 +9,7 @@ from datetime import datetime
 from enum import Enum
 from io import BytesIO
 from time import perf_counter
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import PIL.Image
 import requests
@@ -29,7 +29,12 @@ from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
+from khoj.utils.helpers import (
+    ConversationCommand,
+    in_debug_mode,
+    is_none_or_empty,
+    merge_dicts,
+)
 
 logger = logging.getLogger(__name__)
 model_to_prompt_size = {
@@ -139,6 +144,50 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
     return chat_history
 
 
+def construct_tool_chat_history(
+    previous_iterations: List[InformationCollectionIteration], tool: Optional[ConversationCommand] = None
+) -> Dict[str, list]:
+    """
+    Build a chat-history style dict from previous research iterations.
+
+    Every iteration adds a "you" entry holding its query and a "khoj" entry
+    holding its summarized result. The inferred queries on the "khoj" entry are
+    taken from the iteration's context (Notes), onlineContext keys (Online) or
+    codeContext keys (Code); for any other tool, or None, they are an empty list.
+    """
+    chat_history: list = []
+    # Default extractor: no inferred queries unless the tool has a matching branch below
+    inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
+    if tool == ConversationCommand.Notes:
+        inferred_query_extractor = (
+            lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
+        )
+    elif tool == ConversationCommand.Online:
+        inferred_query_extractor = (
+            lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
+        )
+    elif tool == ConversationCommand.Code:
+        inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
+    for iteration in previous_iterations:
+        chat_history += [
+            {
+                "by": "you",
+                "message": iteration.query,
+            },
+            {
+                "by": "khoj",
+                "intent": {
+                    "type": "remember",
+                    "inferred-queries": inferred_query_extractor(iteration),
+                    "query": iteration.query,
+                },
+                "message": iteration.summarizedResult,
+            },
+        ]
+
+    return {"chat": chat_history}
+
+
 class ChatEvent(Enum):
     START_LLM_RESPONSE = "start_llm_response"
     END_LLM_RESPONSE = "end_llm_response"
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index 3ca0839b..83aecc7e 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -12,6 +12,7 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import (
     InformationCollectionIteration,
     construct_iteration_history,
+    construct_tool_chat_history,
     remove_json_codeblock,
 )
 from khoj.processor.tools.online_search import read_webpages, search_online
@@ -147,7 +148,6 @@ async def execute_information_collection(
         code_results: Dict = dict()
         document_results: List[Dict[str, str]] = []
         summarize_files: str = ""
-        inferred_queries: List[Any] = []
         this_iteration = InformationCollectionIteration(tool=None, query=query)
         previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
 
@@ -174,7 +174,7 @@ async def execute_information_collection(
             document_results = []
             async for result in extract_references_and_questions(
                 request,
-                conversation_history,
+                construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
                 this_iteration.query,
                 7,
                 None,
@@ -208,7 +208,7 @@ async def execute_information_collection(
         elif this_iteration.tool == ConversationCommand.Online:
             async for result in search_online(
                 this_iteration.query,
-                conversation_history,
+                construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
                 location,
                 user,
                 send_status_func,
@@ -228,7 +228,7 @@ async def execute_information_collection(
             try:
                 async for result in read_webpages(
                     this_iteration.query,
-                    conversation_history,
+                    construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
                     location,
                     user,
                     send_status_func,
@@ -258,8 +258,8 @@ async def execute_information_collection(
             try:
                 async for result in run_code(
                     this_iteration.query,
-                    conversation_history,
-                    previous_iterations_history,
+                    construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
+                    "",
                     location,
                     user,
                     send_status_func,
@@ -286,7 +286,7 @@ async def execute_information_collection(
                     this_iteration.query,
                     user,
                     file_filters,
-                    conversation_history,
+                    construct_tool_chat_history(previous_iterations),
                     query_images=query_images,
                     agent=agent,
                     send_status_func=send_status_func,