From bc95a99fb4fb3937751859e2824cb2a1081a07a1 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 9 Nov 2024 18:22:46 -0800 Subject: [PATCH] Make tracer the last input parameter for all the relevant chat helper methods --- .../conversation/anthropic/anthropic_chat.py | 4 +- .../conversation/google/gemini_chat.py | 5 +-- .../conversation/offline/chat_model.py | 4 +- src/khoj/processor/conversation/openai/gpt.py | 4 +- src/khoj/processor/conversation/utils.py | 2 +- src/khoj/processor/image/generate.py | 4 +- src/khoj/processor/tools/online_search.py | 6 +-- src/khoj/processor/tools/run_code.py | 2 +- src/khoj/routers/api.py | 10 ++--- src/khoj/routers/api_chat.py | 28 ++++++------- src/khoj/routers/helpers.py | 42 +++++++++---------- src/khoj/routers/research.py | 2 +- 12 files changed, 56 insertions(+), 57 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 1d139604..6989f4c1 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -36,8 +36,8 @@ def extract_questions_anthropic( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -154,8 +154,8 @@ def converse_anthropic( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Converse with user using Anthropic's Claude diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 6d257faa..e4de609f 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -37,8 +37,8 @@ def extract_questions_gemini( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -122,7 +122,6 @@ def gemini_send_message_to_model( temperature=0, model_kwargs=None, tracer={}, - attached_files: str = None, ): """ Send message to model @@ -165,8 +164,8 @@ def converse_gemini( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, - tracer={}, attached_files: str = None, + tracer={}, ): """ Converse with user using Google's Gemini diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index d0b62f3d..6a25e258 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -37,8 +37,8 @@ def extract_questions_offline( max_prompt_size: int = None, temperature: float = 0.7, personality_context: Optional[str] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -154,8 +154,8 @@ def converse_offline( location_data: LocationData = None, user_name: str = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 65cdfa3f..f2919afb 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -34,8 +34,8 @@ def extract_questions( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -154,8 +154,8 @@ def converse( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Converse with user using OpenAI's ChatGPT diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 791a98e0..248afa81 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -224,8 +224,8 @@ def save_to_conversation_log( automation_id: str = None, query_images: List[str] = None, raw_attached_files: List[FileAttachment] = [], - tracer: Dict[str, Any] = {}, train_of_thought: List[Any] = [], + tracer: Dict[str, Any] = {}, ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index ec5254ec..003dae4d 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -28,8 +28,8 @@ async def text_to_image( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): status_code = 200 image = None @@ -70,8 +70,8 @@ async def text_to_image( query_images=query_images, user=user, agent=agent, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) if send_status_func: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 3b4bd16a..6bd14976 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -67,8 +67,8 @@ async def search_online( max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, query_images: List[str] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): query += " ".join(custom_filters) if not is_internet_connected(): @@ -165,9 +165,9 @@ async def read_webpages( send_status_func: Optional[Callable] = None, query_images: List[str] = None, agent: Agent = None, - tracer: dict = {}, max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, attached_files: str = None, + tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") @@ -178,8 +178,8 @@ async def read_webpages( user, query_images, agent=agent, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) # Get the top 10 web pages to read diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 418ab3a2..d5770ca0 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -31,8 +31,8 @@ async def run_code( query_images: List[str] = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): # Generate Code if send_status_func: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 5474497d..ec3ae759 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -350,8 +350,8 @@ async def extract_references_and_questions( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): user = request.user.object if request.user.is_authenticated else None @@ -425,8 +425,8 @@ async def extract_references_and_questions( user=user, max_prompt_size=conversation_config.max_prompt_size, personality_context=personality_context, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = conversation_config.openai_config @@ -444,8 +444,8 @@ async def extract_references_and_questions( query_images=query_images, vision_enabled=vision_enabled, personality_context=personality_context, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -460,8 +460,8 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -477,8 +477,8 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 390223d9..e9a844b0 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -756,8 +756,8 @@ async def chat( user=user, query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ) # If we're doing research, we don't want to do anything else @@ -797,8 +797,8 @@ async def chat( user_name=user_name, location=location, file_filters=conversation.file_filters if conversation else [], - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(research_result, InformationCollectionIteration): if research_result.summarizedResult: @@ -849,8 +849,8 @@ async def chat( query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(response, dict) and ChatEvent.STATUS in response: yield response[ChatEvent.STATUS] @@ -870,9 +870,9 @@ async def chat( client_application=request.user.client_app, conversation_id=conversation_id, query_images=uploaded_images, - tracer=tracer, train_of_thought=train_of_thought, raw_attached_files=raw_attached_files, + tracer=tracer, ) return @@ -916,9 +916,9 @@ async def chat( inferred_queries=[query_to_run], automation_id=automation.id, query_images=uploaded_images, - tracer=tracer, train_of_thought=train_of_thought, raw_attached_files=raw_attached_files, + tracer=tracer, ) async for result in send_llm_response(llm_response): yield result @@ -940,8 +940,8 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -986,8 +986,8 @@ async def chat( custom_filters, query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1012,8 +1012,8 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1053,8 +1053,8 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1093,8 +1093,8 @@ async def chat( send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1125,10 +1125,10 @@ async def chat( compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, - tracer=tracer, train_of_thought=train_of_thought, attached_file_context=attached_file_context, raw_attached_files=raw_attached_files, + tracer=tracer, ) content_obj = { "intentType": intent_type, @@ -1157,8 +1157,8 @@ async def chat( user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), - tracer=tracer, attached_files=attached_file_context, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1186,10 +1186,10 @@ async def chat( compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, - tracer=tracer, train_of_thought=train_of_thought, attached_file_context=attached_file_context, raw_attached_files=raw_attached_files, + tracer=tracer, ) async for result in send_llm_response(json.dumps(content_obj)): @@ -1215,10 +1215,10 @@ async def chat( user_name, researched_results, uploaded_images, - tracer, train_of_thought, attached_file_context, raw_attached_files, + tracer, ) # Send Response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ea5cca71..aed76ad1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -361,8 +361,8 @@ async def aget_relevant_information_sources( user: KhojUser, query_images: List[str] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -399,8 +399,8 @@ async def aget_relevant_information_sources( relevant_tools_prompt, response_type="json_object", user=user, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) try: @@ -509,8 +509,8 @@ async def infer_webpage_urls( user: KhojUser, query_images: List[str] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ) -> List[str]: """ Infer webpage links from the given query @@ -539,8 +539,8 @@ async def infer_webpage_urls( query_images=query_images, response_type="json_object", user=user, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list of URLs @@ -565,8 +565,8 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ) -> List[str]: """ Generate subqueries from the given query @@ -595,8 +595,8 @@ async def generate_online_subqueries( query_images=query_images, response_type="json_object", user=user, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list @@ -718,8 +718,8 @@ async def generate_summary_from_files( query_images: List[str] = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): try: file_objects = None @@ -781,8 +781,8 @@ async def generate_excalidraw_diagram( user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): if send_status_func: async for event in send_status_func("**Enhancing the Diagramming Prompt**"): @@ -797,8 +797,8 @@ async def generate_excalidraw_diagram( query_images=query_images, user=user, agent=agent, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) if send_status_func: @@ -824,8 +824,8 @@ async def generate_better_diagram_description( query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ) -> str: """ Generate a diagram description from the given query and context @@ -866,8 +866,8 @@ async def generate_better_diagram_description( improve_diagram_description_prompt, query_images=query_images, user=user, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -914,8 +914,8 @@ async def generate_better_image_prompt( query_images: Optional[List[str]] = None, user: KhojUser = None, agent: Agent = None, - tracer: dict = {}, attached_files: str = "", + tracer: dict = {}, ) -> str: """ Generate a better image prompt from the given query @@ -963,7 +963,7 @@ async def generate_better_image_prompt( with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper( - image_prompt, query_images=query_images, user=user, tracer=tracer, attached_files=attached_files + image_prompt, query_images=query_images, user=user, attached_files=attached_files, tracer=tracer ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -979,8 +979,8 @@ async def send_message_to_model_wrapper( user: KhojUser = None, query_images: List[str] = None, context: str = "", - tracer: dict = {}, attached_files: str = None, + tracer: dict = {}, ): conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled @@ -1106,8 +1106,8 @@ def send_message_to_model_wrapper_sync( system_message: str = "", response_type: str = "text", user: KhojUser = None, - tracer: dict = {}, attached_files: str = "", + tracer: dict = {}, ): conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) @@ -1225,10 +1225,10 @@ def generate_chat_response( user_name: Optional[str] = None, meta_research: str = "", query_images: Optional[List[str]] = None, - tracer: dict = {}, train_of_thought: List[Any] = [], attached_files: str = None, raw_attached_files: List[FileAttachment] = None, + tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1252,9 +1252,9 @@ def generate_chat_response( client_application=client_application, conversation_id=conversation_id, query_images=query_images, - tracer=tracer, train_of_thought=train_of_thought, raw_attached_files=raw_attached_files, + tracer=tracer, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) @@ -1281,8 +1281,8 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -1307,8 +1307,8 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -1330,8 +1330,8 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -1352,8 +1352,8 @@ def generate_chat_response( agent=agent, query_images=query_images, vision_available=vision_available, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ) metadata.update({"chat_model": conversation_config.chat_model}) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index dc34009c..c7755b0a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -268,8 +268,8 @@ async def execute_information_collection( send_status_func, query_images=query_images, agent=agent, - tracer=tracer, attached_files=attached_files, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS]