From f867d5ed72b87787b1c72d3ddb3c16d2b2a0a5b5 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Wed, 9 Oct 2024 15:54:25 -0700
Subject: [PATCH 01/88] Working prototype of meta-level chain of reasoning and execution

- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information
- Lots of hacks and code that needs to be reversed later on before submission
---
 src/khoj/processor/conversation/prompts.py | 41 +
 src/khoj/routers/api.py | 7 +-
 src/khoj/routers/api_chat.py | 1055 ++++++++++----------
 src/khoj/routers/helpers.py | 66 +-
 src/khoj/routers/research.py | 261 +++++
 src/khoj/utils/helpers.py | 7 +
 6 files changed, 906 insertions(+), 531 deletions(-)
 create mode 100644 src/khoj/routers/research.py

diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index fa62dbb2..16a9ff67 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -485,6 +485,47 @@ Khoj:
 """.strip()
 )
 
+plan_function_execution = PromptTemplate.from_template(
+    """
+You are Khoj, an extremely smart and helpful search assistant.
+{personality_context}
+- You have access to a variety of data sources to help you answer the user's question
+- You can use the data sources listed below to collect more relevant information, one at a time
+- You are given multiple iterations to work with these data sources to answer the user's question
+- You are provided with additional context. If you have enough context to answer the question, then exit execution
+
+If you already know the answer to the question, return an empty response, e.g., {{}}.
+
+Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:
+
+{tools}
+
+Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.
+
+Previous Iterations:
+{previous_iterations}
+
+Response format:
+{{"data_source": "", "query": ""}}
+
+Chat History:
+{chat_history}
+
+Q: {query}
+Khoj:
+""".strip()
+)
+
+previous_iteration = PromptTemplate.from_template(
+    """
+data_source: {data_source}
+query: {query}
+context: {context}
+onlineContext: {onlineContext}
+---
+""".strip()
+)
+
 pick_relevant_information_collection_tools = PromptTemplate.from_template(
     """
 You are Khoj, an extremely smart and helpful search assistant.
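The plan_function_execution template above defines a small JSON protocol: on each iteration the model is expected to reply with a single {"data_source": "...", "query": "..."} object, or an empty object once it has enough context to answer. The sketch below shows how such a planner reply could be parsed into the next research step; the PlannerAction container and the fall-back-to-stop behaviour are illustrative assumptions, not code from this patch (the patch's own parsing lives in apick_next_tool in src/khoj/routers/research.py).

import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class PlannerAction:
    # Hypothetical container for one planner step; the patch itself uses
    # InformationCollectionIteration in src/khoj/routers/research.py.
    data_source: Optional[str]  # e.g. "notes", "online", "webpage", "summarize"
    query: Optional[str]        # follow-up query to run against that data source

def parse_planner_reply(raw_reply: str) -> PlannerAction:
    """Parse the model's JSON reply; an empty or malformed object means stop researching."""
    try:
        reply = json.loads(raw_reply.strip().strip("`"))  # tolerate stray code fences
    except json.JSONDecodeError:
        return PlannerAction(data_source=None, query=None)
    return PlannerAction(data_source=reply.get("data_source"), query=reply.get("query"))

# Example replies in the format the prompt requests
print(parse_planner_reply('{"data_source": "online", "query": "khoj research mode"}'))
print(parse_planner_reply("{}"))  # planner is done; the caller exits the research loop
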
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 11ab1112..d26b7b5a 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -355,9 +355,10 @@ async def extract_references_and_questions( agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) if ( - not ConversationCommand.Notes in conversation_commands - and not ConversationCommand.Default in conversation_commands - and not agent_has_entries + # not ConversationCommand.Notes in conversation_commands + # and not ConversationCommand.Default in conversation_commands + # and not agent_has_entries + True ): yield compiled_references, inferred_queries, q return diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4acefe30..af19a40c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -41,7 +41,8 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, construct_automation_created_message, create_automation, - extract_relevant_summary, + extract_relevant_info, + generate_summary_from_files, get_conversation_command, is_query_empty, is_ready_to_chat, @@ -49,6 +50,10 @@ from khoj.routers.helpers import ( update_telemetry_state, validate_conversation_config, ) +from khoj.routers.research import ( + InformationCollectionIteration, + execute_information_collection, +) from khoj.routers.storage import upload_image_to_bucket from khoj.utils import state from khoj.utils.helpers import ( @@ -689,6 +694,522 @@ async def chat( meta_log = conversation.conversation_log is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + pending_research = True + + researched_results = "" + online_results: Dict = dict() + ## Extract Document References + compiled_references, inferred_queries, defiltered_query = [], [], None + + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + async for research_result in execute_information_collection( + request=request, + user=user, + query=q, + conversation_id=conversation_id, + conversation_history=meta_log, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + send_status_func=partial(send_event, ChatEvent.STATUS), + location=location, + file_filters=conversation.file_filters if conversation else [], + ): + if type(research_result) == InformationCollectionIteration: + pending_research = False + if research_result.onlineContext: + researched_results += str(research_result.onlineContext) + online_results.update(research_result.onlineContext) + + if research_result.context: + researched_results += str(research_result.context) + compiled_references.extend(research_result.context) + + else: + yield research_result + + researched_results = await extract_relevant_info(q, researched_results, agent) + + logger.info(f"Researched Results: {researched_results}") + + pending_research = False + + conversation_commands = await aget_relevant_information_sources( + q, + meta_log, + is_automated_task, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + ) + conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) + async for result in send_event( + ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" + ): + yield result + + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent) + async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): + yield result + if mode not in 
conversation_commands: + conversation_commands.append(mode) + + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] + file_filters = conversation.file_filters if conversation else [] + # Skip trying to summarize if + if ( + # summarization intent was inferred + ConversationCommand.Summarize in conversation_commands + # and not triggered via slash command + and not used_slash_summarize + # but we can't actually summarize + and len(file_filters) != 1 + # not pending research + and not pending_research + ): + conversation_commands.remove(ConversationCommand.Summarize) + elif ConversationCommand.Summarize in conversation_commands and pending_research: + response_log = "" + agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + if len(file_filters) == 0 and not agent_has_entries: + response_log = "No files selected for summarization. Please add files using the section on the left." + async for result in send_llm_response(response_log): + yield result + elif len(file_filters) > 1 and not agent_has_entries: + response_log = "Only one file can be selected for summarization." + async for result in send_llm_response(response_log): + yield result + else: + response_log = await generate_summary_from_files( + q=query, + user=user, + file_filters=file_filters, + meta_log=meta_log, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + send_response_func=partial(send_llm_response), + ) + await sync_to_async(save_to_conversation_log)( + q, + response_log, + user, + meta_log, + user_message_time, + intent_type="summarize", + client_application=request.user.client_app, + conversation_id=conversation_id, + uploaded_image_url=uploaded_image_url, + ) + return + + custom_filters = [] + if conversation_commands == [ConversationCommand.Help]: + if not q: + conversation_config = await ConversationAdapters.aget_user_conversation_config(user) + if conversation_config == None: + conversation_config = await ConversationAdapters.aget_default_conversation_config() + model_type = conversation_config.model_type + formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) + async for result in send_llm_response(formatted_help): + yield result + return + # Adding specification to search online specifically on khoj.dev pages. + custom_filters.append("site:khoj.dev") + conversation_commands.append(ConversationCommand.Online) + + if ConversationCommand.Automation in conversation_commands: + try: + automation, crontime, query_to_run, subject = await create_automation( + q, timezone, user, request.url, meta_log + ) + except Exception as e: + logger.error(f"Error scheduling task {q} for {user.email}: {e}") + error_message = f"Unable to create automation. Ensure the automation doesn't already exist." 
+ async for result in send_llm_response(error_message): + yield result + return + + llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) + await sync_to_async(save_to_conversation_log)( + q, + llm_response, + user, + meta_log, + user_message_time, + intent_type="automation", + client_application=request.user.client_app, + conversation_id=conversation_id, + inferred_queries=[query_to_run], + automation_id=automation.id, + uploaded_image_url=uploaded_image_url, + ) + async for result in send_llm_response(llm_response): + yield result + return + + # Gather Context + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + d, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + + if not is_none_or_empty(compiled_references): + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + yield result + + if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): + async for result in send_llm_response(f"{no_entries_found.format()}"): + yield result + return + + if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): + conversation_commands.remove(ConversationCommand.Notes) + + ## Gather Online References + if ConversationCommand.Online in conversation_commands and pending_research: + try: + async for result in search_online( + defiltered_query, + meta_log, + location, + user, + subscribed, + partial(send_event, ChatEvent.STATUS), + custom_filters, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result + except ValueError as e: + error_message = f"Error searching online: {e}. Attempting to respond without online results" + logger.warning(error_message) + async for result in send_llm_response(error_message): + yield result + return + + ## Gather Webpage References + if ConversationCommand.Webpage in conversation_commands and pending_research: + try: + async for result in read_webpages( + defiltered_query, + meta_log, + location, + user, + subscribed, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): + yield result + except ValueError as e: + logger.warning( + f"Error directly reading webpages: {e}. 
Attempting to respond without online results", + exc_info=True, + ) + + ## Send Gathered References + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": online_results, + }, + ): + yield result + + # Generate Output + ## Generate Image Output + if ConversationCommand.Image in conversation_commands: + async for result in text_to_image( + q, + user, + meta_log, + location_data=location, + references=compiled_references, + online_results=online_results, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + image, status_code, improved_image_prompt, intent_type = result + + if image is None or status_code != 200: + content_obj = { + "content-type": "application/json", + "intentType": intent_type, + "detail": improved_image_prompt, + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + await sync_to_async(save_to_conversation_log)( + q, + image, + user, + meta_log, + user_message_time, + intent_type=intent_type, + inferred_queries=[improved_image_prompt], + client_application=request.user.client_app, + conversation_id=conversation_id, + compiled_references=compiled_references, + online_results=online_results, + uploaded_image_url=uploaded_image_url, + ) + content_obj = { + "intentType": intent_type, + "inferredQueries": [improved_image_prompt], + "image": image, + } + async for result in send_llm_response(json.dumps(content_obj)): + yield result + return + + ## Generate Text Output + async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): + yield result + llm_response, chat_metadata = await agenerate_chat_response( + defiltered_query, + meta_log, + conversation, + researched_results, + compiled_references, + online_results, + inferred_queries, + conversation_commands, + user, + request.user.client_app, + conversation_id, + location, + user_name, + uploaded_image_url, + ) + + # Send Response + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + + continue_stream = True + iterator = AsyncIteratorWrapper(llm_response) + async for item in iterator: + if item is None: + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + logger.debug("Finished streaming response") + return + if not connection_alive or not continue_stream: + continue + try: + async for result in send_event(ChatEvent.MESSAGE, f"{item}"): + yield result + except Exception as e: + continue_stream = False + logger.info(f"User {user} disconnected. 
Emitting rest of responses to clear thread: {e}") + + ## Stream Text Response + if stream: + return StreamingResponse(event_generator(q, image=image), media_type="text/plain") + ## Non-Streaming Text Response + else: + response_iterator = event_generator(q, image=image) + response_data = await read_chat_stream(response_iterator) + return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) + + +# @api_chat.post("") +@requires(["authenticated"]) +async def old_chat( + request: Request, + common: CommonQueryParams, + body: ChatRequestBody, + rate_limiter_per_minute=Depends( + ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute") + ), + rate_limiter_per_day=Depends( + ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") + ), +): + # Access the parameters from the body + q = body.q + n = body.n + d = body.d + stream = body.stream + title = body.title + conversation_id = body.conversation_id + city = body.city + region = body.region + country = body.country or get_country_name_from_timezone(body.timezone) + country_code = body.country_code or get_country_code_from_timezone(body.timezone) + timezone = body.timezone + image = body.image + + async def event_generator(q: str, image: str): + start_time = time.perf_counter() + ttft = None + chat_metadata: dict = {} + connection_alive = True + user: KhojUser = request.user.object + subscribed: bool = has_required_scope(request, ["premium"]) + event_delimiter = "␃🔚␗" + q = unquote(q) + nonlocal conversation_id + + uploaded_image_url = None + if image: + decoded_string = unquote(image) + base64_data = decoded_string.split(",", 1)[1] + image_bytes = base64.b64decode(base64_data) + webp_image_bytes = convert_image_to_webp(image_bytes) + try: + uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) + except: + uploaded_image_url = None + + async def send_event(event_type: ChatEvent, data: str | dict): + nonlocal connection_alive, ttft + if not connection_alive or await request.is_disconnected(): + connection_alive = False + logger.warning(f"User {user} disconnected from {common.client} client") + return + try: + if event_type == ChatEvent.END_LLM_RESPONSE: + collect_telemetry() + if event_type == ChatEvent.START_LLM_RESPONSE: + ttft = time.perf_counter() - start_time + if event_type == ChatEvent.MESSAGE: + yield data + elif event_type == ChatEvent.REFERENCES or stream: + yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) + except asyncio.CancelledError as e: + connection_alive = False + logger.warn(f"User {user} disconnected from {common.client} client: {e}") + return + except Exception as e: + connection_alive = False + logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) + return + finally: + yield event_delimiter + + async def send_llm_response(response: str): + async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): + yield result + async for result in send_event(ChatEvent.MESSAGE, response): + yield result + async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): + yield result + + def collect_telemetry(): + # Gather chat response telemetry + nonlocal chat_metadata + latency = time.perf_counter() - start_time + cmd_set = set([cmd.value for cmd in conversation_commands]) + chat_metadata = chat_metadata or {} + chat_metadata["conversation_command"] = cmd_set + chat_metadata["agent"] = conversation.agent.slug if 
conversation.agent else None + chat_metadata["latency"] = f"{latency:.3f}" + chat_metadata["ttft_latency"] = f"{ttft:.3f}" + + logger.info(f"Chat response time to first token: {ttft:.3f} seconds") + logger.info(f"Chat response total time: {latency:.3f} seconds") + update_telemetry_state( + request=request, + telemetry_type="api", + api="chat", + client=request.user.client_app, + user_agent=request.headers.get("user-agent"), + host=request.headers.get("host"), + metadata=chat_metadata, + ) + + conversation_commands = [get_conversation_command(query=q, any_references=True)] + + conversation = await ConversationAdapters.aget_conversation_by_user( + user, + client_application=request.user.client_app, + conversation_id=conversation_id, + title=title, + create_new=body.create_new, + ) + if not conversation: + async for result in send_llm_response(f"Conversation {conversation_id} not found"): + yield result + return + conversation_id = conversation.id + + agent: Agent | None = None + default_agent = await AgentAdapters.aget_default_agent() + if conversation.agent and conversation.agent != default_agent: + agent = conversation.agent + + if not conversation.agent: + conversation.agent = default_agent + await conversation.asave() + agent = default_agent + + await is_ready_to_chat(user) + + user_name = await aget_user_name(user) + location = None + if city or region or country or country_code: + location = LocationData(city=city, region=region, country=country, country_code=country_code) + + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started."): + yield result + return + + user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + meta_log = conversation.conversation_log + is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] + if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( q, @@ -738,47 +1259,15 @@ async def chat( async for result in send_llm_response(response_log): yield result else: - try: - file_object = None - if await EntryAdapters.aagent_has_entries(agent): - file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) - if len(file_names) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name( - None, file_names[0], agent - ) - - if len(file_filters) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - - if len(file_object) == 0: - response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary( - q, - contextual_data, - conversation_history=meta_log, - subscribed=subscribed, - uploaded_image_url=uploaded_image_url, - agent=agent, - ) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file. Please try again, or contact support." 
- logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result + response_log = await generate_summary_from_files( + q=query, + user=user, + file_filters=file_filters, + meta_log=meta_log, + subscribed=subscribed, + send_status_func=partial(send_event, ChatEvent.STATUS), + send_response_func=partial(send_llm_response), + ) await sync_to_async(save_to_conversation_log)( q, response_log, @@ -867,8 +1356,6 @@ async def chat( async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): yield result - online_results: Dict = dict() - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): async for result in send_llm_response(f"{no_entries_found.format()}"): yield result @@ -1049,483 +1536,3 @@ async def chat( response_iterator = event_generator(q, image=image) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) - - -# Deprecated API. Remove by end of September 2024 -@api_chat.get("") -@requires(["authenticated"]) -async def get_chat( - request: Request, - common: CommonQueryParams, - q: str, - n: int = 7, - d: float = None, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[str] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - image: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -): - # Issue a deprecation warning - warnings.warn( - "The 'get_chat' API endpoint is deprecated. 
It will be removed by the end of September 2024.", - DeprecationWarning, - stacklevel=2, - ) - - async def event_generator(q: str, image: str): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - connection_alive = True - user: KhojUser = request.user.object - subscribed: bool = has_required_scope(request, ["premium"]) - event_delimiter = "␃🔚␗" - q = unquote(q) - nonlocal conversation_id - - uploaded_image_url = None - if image: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - try: - uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) - except: - uploaded_image_url = None - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal connection_alive, ttft - if not connection_alive or await request.is_disconnected(): - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client") - return - try: - if event_type == ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - if event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client: {e}") - return - except Exception as e: - connection_alive = False - logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) - return - finally: - yield event_delimiter - - async def send_llm_response(response: str): - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - chat_metadata["latency"] = f"{latency:.3f}" - chat_metadata["ttft_latency"] = f"{ttft:.3f}" - - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=request.user.client_app, conversation_id=conversation_id, title=title - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found"): - yield result - return - conversation_id = conversation.id - agent = conversation.agent if conversation.agent else None - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - location = None - if city or region or country: - location = LocationData(city=city, region=region, country=country) - 
- if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources( - q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url - ) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event( - ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" - ): - yield result - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url) - async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): - yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = "No files selected for summarization. Please add files using the section on the left." - async for result in send_llm_response(response_log): - yield result - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." - async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary( - q, - contextual_data, - conversation_history=meta_log, - subscribed=subscribed, - uploaded_image_url=uploaded_image_url, - ) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file." 
- logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, - ) - return - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - async for result in send_llm_response(formatted_help): - yield result - return - # Adding specification to search online specifically on khoj.dev pages. - custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_llm_response(error_message): - yield result - return - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - uploaded_image_url=uploaded_image_url, - ) - async for result in send_llm_response(llm_response): - yield result - return - - # Gather Context - ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - online_results: Dict = dict() - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}"): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in 
search_online( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - custom_filters, - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except ValueError as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_llm_response(error_message): - yield result - return - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", - exc_info=True, - ) - - ## Send Gathered References - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": online_results, - }, - ): - yield result - - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - subscribed=subscribed, - send_status_func=partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - image, status_code, improved_image_prompt, intent_type = result - - if image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - uploaded_image_url=uploaded_image_url, - ) - content_obj = { - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): - yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - 
request.user.client_app, - conversation_id, - location, - user_name, - uploaded_image_url, - ) - - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - - continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: - if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return - if not connection_alive or not continue_stream: - continue - try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): - yield result - except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, image=image), media_type="text/plain") - ## Non-Streaming Text Response - else: - response_iterator = event_generator(q, image=image) - response_data = await read_chat_stream(response_iterator) - return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index fdb1aa12..279ad85e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -14,6 +14,7 @@ from typing import ( Annotated, Any, AsyncGenerator, + Callable, Dict, Iterator, List, @@ -39,6 +40,7 @@ from khoj.database.adapters import ( AutomationAdapters, ConversationAdapters, EntryAdapters, + FileObjectAdapters, create_khoj_token, get_khoj_tokens, get_user_name, @@ -614,6 +616,58 @@ async def extract_relevant_summary( return response.strip() +async def generate_summary_from_files( + q: str, + user: KhojUser, + file_filters: List[str], + meta_log: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + send_status_func: Optional[Callable] = None, + send_response_func: Optional[Callable] = None, +): + try: + file_object = None + if await EntryAdapters.aagent_has_entries(agent): + file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) + if len(file_names) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent) + + if len(file_filters) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + + if len(file_object) == 0: + response_log = ( + "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again." + ) + async for result in send_response_func(response_log): + yield result + return + contextual_data = " ".join([file.raw_text for file in file_object]) + if not q: + q = "Create a general summary of the file" + async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"): + yield result + + response = await extract_relevant_summary( + q, + contextual_data, + conversation_history=meta_log, + subscribed=subscribed, + uploaded_image_url=uploaded_image_url, + agent=agent, + ) + response_log = str(response) + async for result in send_response_func(response_log): + yield result + except Exception as e: + response_log = "Error summarizing file. Please try again, or contact support." 
+ logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) + async for result in send_response_func(response_log): + yield result + + async def generate_better_image_prompt( q: str, conversation_history: str, @@ -893,6 +947,7 @@ def generate_chat_response( q: str, meta_log: dict, conversation: Conversation, + meta_research: str = "", compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, inferred_queries: List[str] = [], @@ -910,6 +965,9 @@ def generate_chat_response( metadata = {} agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None + query_to_run = q + if meta_research: + query_to_run = f"AI Research: {meta_research} {q}" try: partial_completion = partial( save_to_conversation_log, @@ -937,7 +995,7 @@ def generate_chat_response( chat_response = converse_offline( references=compiled_references, online_results=online_results, - user_query=q, + user_query=query_to_run, loaded_model=loaded_model, conversation_log=meta_log, completion_func=partial_completion, @@ -956,7 +1014,7 @@ def generate_chat_response( chat_model = conversation_config.chat_model chat_response = converse( compiled_references, - q, + query_to_run, image_url=uploaded_image_url, online_results=online_results, conversation_log=meta_log, @@ -977,7 +1035,7 @@ def generate_chat_response( api_key = conversation_config.openai_config.api_key chat_response = converse_anthropic( compiled_references, - q, + query_to_run, online_results, meta_log, model=conversation_config.chat_model, @@ -994,7 +1052,7 @@ def generate_chat_response( api_key = conversation_config.openai_config.api_key chat_response = converse_gemini( compiled_references, - q, + query_to_run, online_results, meta_log, model=conversation_config.chat_model, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py new file mode 100644 index 00000000..65c3f42d --- /dev/null +++ b/src/khoj/routers/research.py @@ -0,0 +1,261 @@ +import json +import logging +from typing import Any, Callable, Dict, List, Optional + +from fastapi import Request + +from khoj.database.adapters import EntryAdapters +from khoj.database.models import Agent, KhojUser +from khoj.processor.conversation import prompts +from khoj.processor.conversation.utils import remove_json_codeblock +from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.routers.api import extract_references_and_questions +from khoj.routers.helpers import ( + ChatEvent, + construct_chat_history, + generate_summary_from_files, + send_message_to_model_wrapper, +) +from khoj.utils.helpers import ( + ConversationCommand, + function_calling_description_for_llm, + timer, +) +from khoj.utils.rawconfig import LocationData + +logger = logging.getLogger(__name__) + + +class InformationCollectionIteration: + def __init__( + self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None + ): + self.data_source = data_source + self.query = query + self.context = context + self.onlineContext = onlineContext + + +async def apick_next_tool( + query: str, + conversation_history: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + previous_iterations: List[InformationCollectionIteration] = None, +): + """ + Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. 
+ """ + + tool_options = dict() + tool_options_str = "" + + agent_tools = agent.input_tools if agent else [] + + for tool, description in function_calling_description_for_llm.items(): + tool_options[tool.value] = description + if len(agent_tools) == 0 or tool.value in agent_tools: + tool_options_str += f'- "{tool.value}": "{description}"\n' + + chat_history = construct_chat_history(conversation_history) + + previous_iterations_history = "" + for iteration in previous_iterations: + iteration_data = prompts.previous_iteration.format( + query=iteration.query, + data_source=iteration.data_source, + context=str(iteration.context), + onlineContext=str(iteration.onlineContext), + ) + + previous_iterations_history += iteration_data + + if uploaded_image_url: + query = f"[placeholder for user attached image]\n{query}" + + personality_context = ( + prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" + ) + + function_planning_prompt = prompts.plan_function_execution.format( + query=query, + tools=tool_options_str, + chat_history=chat_history, + personality_context=personality_context, + previous_iterations=previous_iterations_history, + ) + + with timer("Chat actor: Infer information sources to refer", logger): + response = await send_message_to_model_wrapper( + function_planning_prompt, + response_type="json_object", + subscribed=subscribed, + ) + + try: + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) + suggested_data_source = response.get("data_source", None) + suggested_query = response.get("query", None) + + return InformationCollectionIteration( + data_source=suggested_data_source, + query=suggested_query, + ) + + except Exception as e: + logger.error(f"Invalid response for determining relevant tools: {response}. 
{e}", exc_info=True) + return InformationCollectionIteration( + data_source=None, + query=None, + ) + + +async def execute_information_collection( + request: Request, + user: KhojUser, + query: str, + conversation_id: str, + conversation_history: dict, + subscribed: bool, + uploaded_image_url: str = None, + agent: Agent = None, + send_status_func: Optional[Callable] = None, + location: LocationData = None, + file_filters: List[str] = [], +): + iteration = 0 + MAX_ITERATIONS = 2 + previous_iterations = [] + while iteration < MAX_ITERATIONS: + online_results: Dict = dict() + compiled_references, inferred_queries, defiltered_query = [], [], None + this_iteration = await apick_next_tool( + query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations + ) + if this_iteration.data_source == ConversationCommand.Notes: + ## Extract Document References + compiled_references, inferred_queries, defiltered_query = [], [], None + async for result in extract_references_and_questions( + request, + conversation_history, + this_iteration.query, + 7, + None, + conversation_id, + [ConversationCommand.Default], + location, + send_status_func, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context=str(compiled_references), + ) + ) + + elif this_iteration.data_source == ConversationCommand.Online: + async for result in search_online( + this_iteration.query, + conversation_history, + location, + user, + subscribed, + send_status_func, + [], + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + online_results = result + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + onlineContext=online_results, + ) + ) + + elif this_iteration.data_source == ConversationCommand.Webpage: + async for result in read_webpages( + this_iteration.query, + conversation_history, + location, + user, + subscribed, + send_status_func, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages = result + + webpages = [] + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: + webpages.append(webpage["link"]) + yield send_status_func(f"**Read web pages**: {webpages}") + + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + onlineContext=online_results, + ) + ) + + elif this_iteration.data_source == ConversationCommand.Summarize: + response_log = "" + agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + if len(file_filters) == 0 and not agent_has_entries: + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context="No files selected for 
summarization.", + ) + ) + elif len(file_filters) > 1 and not agent_has_entries: + response_log = "Only one file can be selected for summarization." + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context=response_log, + ) + ) + else: + response_log = await generate_summary_from_files( + q=query, + user=user, + file_filters=file_filters, + meta_log=conversation_history, + subscribed=subscribed, + send_status_func=send_status_func, + ) + else: + iteration = MAX_ITERATIONS + + iteration += 1 + for completed_iter in previous_iterations: + yield completed_iter diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 4c7bf985..8538aace 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -345,6 +345,13 @@ tool_descriptions_for_llm = { ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.", } +function_calling_description_for_llm = { + ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.", + ConversationCommand.Online: "Use this if you think the there's important information on the internet related to the query.", + ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are share of a webpage URL that will help you directly answer this query", + ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.", +} + mode_descriptions_for_llm = { ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.", ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.", From c91678078d63c9354d25ea4b409962fb0f89fc10 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 15:55:55 -0700 Subject: [PATCH 02/88] Correct the usage of query passed to summarize function --- src/khoj/routers/api_chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index af19a40c..7b0241b5 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -695,7 +695,6 @@ async def chat( is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] pending_research = True - researched_results = "" online_results: Dict = dict() ## Extract Document References @@ -785,7 +784,7 @@ async def chat( yield result else: response_log = await generate_summary_from_files( - q=query, + q=q, user=user, file_filters=file_filters, meta_log=meta_log, @@ -1260,7 +1259,7 @@ async def old_chat( yield result else: response_log = await generate_summary_from_files( - q=query, + q=q, user=user, file_filters=file_filters, meta_log=meta_log, From 4fbaef10e93e2b89cabcee09a188cec392b31a74 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 15:58:05 -0700 Subject: [PATCH 03/88] Correct usage of the summarize function --- src/khoj/routers/api_chat.py | 6 ++++-- src/khoj/routers/research.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 7b0241b5..64c50a1c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -783,7 +783,7 @@ async def chat( async for result in send_llm_response(response_log): yield result else: - response_log = await generate_summary_from_files( + async for 
response in generate_summary_from_files( q=q, user=user, file_filters=file_filters, @@ -791,7 +791,9 @@ async def chat( subscribed=subscribed, send_status_func=partial(send_event, ChatEvent.STATUS), send_response_func=partial(send_llm_response), - ) + ): + yield response + await sync_to_async(save_to_conversation_log)( q, response_log, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 65c3f42d..4211b072 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -245,14 +245,25 @@ async def execute_information_collection( ) ) else: - response_log = await generate_summary_from_files( + async for response in generate_summary_from_files( q=query, user=user, file_filters=file_filters, meta_log=conversation_history, subscribed=subscribed, send_status_func=send_status_func, - ) + ): + if isinstance(response, dict) and ChatEvent.STATUS in response: + yield response[ChatEvent.STATUS] + else: + response_log = response + previous_iterations.append( + InformationCollectionIteration( + data_source=this_iteration.data_source, + query=this_iteration.query, + context=response_log, + ) + ) else: iteration = MAX_ITERATIONS From 46ef205a754a6a40203eaacc4c3a3fd65f4c8df0 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:01:52 -0700 Subject: [PATCH 04/88] Add additional type annotations for compiled_references et al --- src/khoj/routers/research.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 4211b072..f4d92f51 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -128,10 +128,14 @@ async def execute_information_collection( ): iteration = 0 MAX_ITERATIONS = 2 - previous_iterations = [] + previous_iterations = List[InformationCollectionIteration] while iteration < MAX_ITERATIONS: online_results: Dict = dict() - compiled_references, inferred_queries, defiltered_query = [], [], None + + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + defiltered_query = None + this_iteration = await apick_next_tool( query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations ) From 4978360852f83b35a9c6bf08d0049dd3178be2f4 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:02:41 -0700 Subject: [PATCH 05/88] Fix type of previous_iterations --- src/khoj/routers/research.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f4d92f51..f55fc024 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -128,7 +128,7 @@ async def execute_information_collection( ): iteration = 0 MAX_ITERATIONS = 2 - previous_iterations = List[InformationCollectionIteration] + previous_iterations: List[InformationCollectionIteration] = [] while iteration < MAX_ITERATIONS: online_results: Dict = dict() From 6960fb097c08e18c26631dbc5e5a0a9ea1482014 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:04:39 -0700 Subject: [PATCH 06/88] update types of prev iterations response --- src/khoj/routers/research.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f55fc024..7fb25498 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -27,9 +27,7 @@ logger = logging.getLogger(__name__) class InformationCollectionIteration: - def __init__( - self, data_source: str, query: str, context: str = 
None, onlineContext: str = None, result: Any = None - ): + def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None): self.data_source = data_source self.query = query self.context = context @@ -260,7 +258,7 @@ async def execute_information_collection( if isinstance(response, dict) and ChatEvent.STATUS in response: yield response[ChatEvent.STATUS] else: - response_log = response + response_log = response # type: ignore previous_iterations.append( InformationCollectionIteration( data_source=this_iteration.data_source, From f7e6f99a32f4bbf691b05db2f823e7e7db55126b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:05:34 -0700 Subject: [PATCH 07/88] add typing for extract document references --- src/khoj/routers/api_chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 64c50a1c..98598b36 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -6,7 +6,7 @@ import time import warnings from datetime import datetime from functools import partial -from typing import Dict, Optional +from typing import Any, Dict, List, Optional from urllib.parse import unquote from asgiref.sync import sync_to_async @@ -698,7 +698,9 @@ async def chat( researched_results = "" online_results: Dict = dict() ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None + compiled_references: List[Any] = [] + inferred_queries: List[Any] = [] + defiltered_query: str = None if conversation_commands == [ConversationCommand.Default] or is_automated_task: async for research_result in execute_information_collection( From f71e4969d319143f880e3d03557ed954a207a687 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:40:06 -0700 Subject: [PATCH 08/88] Skip summarize while it's broken, and snip some other parts of the workflow while under construction --- src/khoj/processor/tools/online_search.py | 2 ++ src/khoj/routers/api_chat.py | 19 +++++++++---------- src/khoj/routers/research.py | 8 +++++++- src/khoj/utils/helpers.py | 1 - 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 393442c4..840d8a81 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -45,6 +45,8 @@ OLOSTEP_QUERY_PARAMS = { "expandMarkdown": "True", "expandHtml": "False", } + +# TODO: Should this be 0 to let advanced model decide which web pages to read? 
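+# Assumption: each webpage read adds another scrape request and more prompt tokens, so the limit is kept small for now.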
MAX_WEBPAGES_TO_READ = 1 diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 98598b36..f4061c50 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -743,11 +743,6 @@ async def chat( uploaded_image_url=uploaded_image_url, agent=agent, ) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event( - ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" - ): - yield result mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent) async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): @@ -876,11 +871,15 @@ async def chat( defiltered_query = result[2] if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result + try: + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + yield result + except Exception as e: + # TODO Get correct type for compiled across research notes extraction + logger.error(f"Error extracting references: {e}", exc_info=True) if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): async for result in send_llm_response(f"{no_entries_found.format()}"): diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 7fb25498..f6eae48d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional from fastapi import Request -from khoj.database.adapters import EntryAdapters +from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import remove_json_codeblock @@ -76,6 +76,7 @@ async def apick_next_tool( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) + # TODO Add current date/time to the query function_planning_prompt = prompts.plan_function_execution.format( query=query, tools=tool_options_str, @@ -84,11 +85,14 @@ async def apick_next_tool( previous_iterations=previous_iterations_history, ) + chat_model_option = await ConversationAdapters.aget_advanced_conversation_config() + with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( function_planning_prompt, response_type="json_object", subscribed=subscribed, + chat_model_option=chat_model_option, ) try: @@ -98,6 +102,8 @@ async def apick_next_tool( suggested_data_source = response.get("data_source", None) suggested_query = response.get("query", None) + logger.info(f"Response for determining relevant tools: {response}") + return InformationCollectionIteration( data_source=suggested_data_source, query=suggested_query, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 8538aace..25243a27 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -349,7 +349,6 @@ 
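# Data source descriptions surfaced to the research planner (apick_next_tool in research.py) when it picks the next tool to use.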
function_calling_description_for_llm = { ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.", ConversationCommand.Online: "Use this if you think the there's important information on the internet related to the query.", ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are share of a webpage URL that will help you directly answer this query", - ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.", } mode_descriptions_for_llm = { From 7b288a11793c21ba40e7738128db3285980ba786 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 16:59:20 -0700 Subject: [PATCH 09/88] Clean up the function planning prompt a little bit --- src/khoj/processor/conversation/prompts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 16a9ff67..1246a43a 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -487,7 +487,7 @@ Khoj: plan_function_execution = PromptTemplate.from_template( """ -You are Khoj, an extremely smart and helpful search assistant. +You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query. {personality_context} - You have access to a variety of data sources to help you answer the user's question - You can use the data sources listed below to collect more relevant information, one at a time @@ -512,7 +512,7 @@ Chat History: {chat_history} Q: {query} -Khoj: +Response: """.strip() ) From 5b8d663cf178372a75ce92d07bc85836f607fe5b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 17:40:56 -0700 Subject: [PATCH 10/88] Add intermediate summarization of results when planning with o1 --- src/khoj/processor/conversation/prompts.py | 9 +- src/khoj/routers/api_chat.py | 16 +-- src/khoj/routers/research.py | 137 +++++++++++---------- 3 files changed, 84 insertions(+), 78 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 1246a43a..6ef1cb94 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -490,9 +490,11 @@ plan_function_execution = PromptTemplate.from_template( You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query. {personality_context} - You have access to a variety of data sources to help you answer the user's question -- You can use the data sources listed below to collect more relevant information, one at a time +- You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained. - You are given multiple iterations to with these data sources to answer the user's question - You are provided with additional context. If you have enough context to answer the question, then exit execution +- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration. +- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." and then "Tell me the GDP of ." 
If you already know the answer to the question, return an empty response, e.g., {{}}. @@ -500,7 +502,7 @@ Which of the data sources listed below you would use to answer the user's questi {tools} -Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else. +Provide the data source and associated query in a JSON object. Do not say anything else. Previous Iterations: {previous_iterations} @@ -520,8 +522,7 @@ previous_iteration = PromptTemplate.from_template( """ data_source: {data_source} query: {query} -context: {context} -onlineContext: {onlineContext} +summary: {summary} --- """.strip() ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index f4061c50..158509bd 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -718,18 +718,20 @@ async def chat( ): if type(research_result) == InformationCollectionIteration: pending_research = False - if research_result.onlineContext: - researched_results += str(research_result.onlineContext) - online_results.update(research_result.onlineContext) + # if research_result.onlineContext: + # researched_results += str(research_result.onlineContext) + # online_results.update(research_result.onlineContext) - if research_result.context: - researched_results += str(research_result.context) - compiled_references.extend(research_result.context) + # if research_result.context: + # researched_results += str(research_result.context) + # compiled_references.extend(research_result.context) + + researched_results += research_result.summarizedResult else: yield research_result - researched_results = await extract_relevant_info(q, researched_results, agent) + # researched_results = await extract_relevant_info(q, researched_results, agent) logger.info(f"Researched Results: {researched_results}") diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index f6eae48d..3921cb78 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -13,6 +13,7 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ChatEvent, construct_chat_history, + extract_relevant_info, generate_summary_from_files, send_message_to_model_wrapper, ) @@ -27,11 +28,19 @@ logger = logging.getLogger(__name__) class InformationCollectionIteration: - def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None): + def __init__( + self, + data_source: str, + query: str, + context: str = None, + onlineContext: dict = None, + summarizedResult: str = None, + ): self.data_source = data_source self.query = query self.context = context self.onlineContext = onlineContext + self.summarizedResult = summarizedResult async def apick_next_tool( @@ -63,8 +72,7 @@ async def apick_next_tool( iteration_data = prompts.previous_iteration.format( query=iteration.query, data_source=iteration.data_source, - context=str(iteration.context), - onlineContext=str(iteration.onlineContext), + summary=iteration.summarizedResult, ) previous_iterations_history += iteration_data @@ -138,7 +146,8 @@ async def execute_information_collection( compiled_references: List[Any] = [] inferred_queries: List[Any] = [] - defiltered_query = None + + result: str = "" this_iteration = await apick_next_tool( query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations @@ -165,13 +174,7 @@ async def execute_information_collection( 
compiled_references.extend(result[0]) inferred_queries.extend(result[1]) defiltered_query = result[2] - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=str(compiled_references), - ) - ) + this_iteration.context = str(compiled_references) elif this_iteration.data_source == ConversationCommand.Online: async for result in search_online( @@ -189,13 +192,7 @@ async def execute_information_collection( yield result[ChatEvent.STATUS] else: online_results = result - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - onlineContext=online_results, - ) - ) + this_iteration.onlineContext = online_results elif this_iteration.data_source == ConversationCommand.Webpage: async for result in read_webpages( @@ -224,57 +221,63 @@ async def execute_information_collection( webpages.append(webpage["link"]) yield send_status_func(f"**Read web pages**: {webpages}") - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - onlineContext=online_results, - ) - ) + this_iteration.onlineContext = online_results - elif this_iteration.data_source == ConversationCommand.Summarize: - response_log = "" - agent_has_entries = await EntryAdapters.aagent_has_entries(agent) - if len(file_filters) == 0 and not agent_has_entries: - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context="No files selected for summarization.", - ) - ) - elif len(file_filters) > 1 and not agent_has_entries: - response_log = "Only one file can be selected for summarization." - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=response_log, - ) - ) - else: - async for response in generate_summary_from_files( - q=query, - user=user, - file_filters=file_filters, - meta_log=conversation_history, - subscribed=subscribed, - send_status_func=send_status_func, - ): - if isinstance(response, dict) and ChatEvent.STATUS in response: - yield response[ChatEvent.STATUS] - else: - response_log = response # type: ignore - previous_iterations.append( - InformationCollectionIteration( - data_source=this_iteration.data_source, - query=this_iteration.query, - context=response_log, - ) - ) + # TODO: Fix summarize later + # elif this_iteration.data_source == ConversationCommand.Summarize: + # response_log = "" + # agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + # if len(file_filters) == 0 and not agent_has_entries: + # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context="No files selected for summarization.", + # ) + # ) + # elif len(file_filters) > 1 and not agent_has_entries: + # response_log = "Only one file can be selected for summarization." 
+ # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context=response_log, + # ) + # ) + # else: + # async for response in generate_summary_from_files( + # q=query, + # user=user, + # file_filters=file_filters, + # meta_log=conversation_history, + # subscribed=subscribed, + # send_status_func=send_status_func, + # ): + # if isinstance(response, dict) and ChatEvent.STATUS in response: + # yield response[ChatEvent.STATUS] + # else: + # response_log = response # type: ignore + # previous_iterations.append( + # InformationCollectionIteration( + # data_source=this_iteration.data_source, + # query=this_iteration.query, + # context=response_log, + # ) + # ) else: iteration = MAX_ITERATIONS iteration += 1 - for completed_iter in previous_iterations: - yield completed_iter + + if compiled_references or online_results: + results_data = f"**Results**:\n" + if compiled_references: + results_data += f"**Document References**: {compiled_references}\n" + if online_results: + results_data += f"**Online Results**: {online_results}\n" + + intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) + this_iteration.summarizedResult = intermediate_result + + previous_iterations.append(this_iteration) + yield this_iteration From ab81b01fcbfd4bf39ba41db92f6148cab10dd812 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 17:46:28 -0700 Subject: [PATCH 11/88] Fix typing of direct_web_pages and remove the deprecated chat API --- src/khoj/routers/api_chat.py | 476 ----------------------------------- src/khoj/routers/research.py | 2 +- 2 files changed, 1 insertion(+), 477 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 158509bd..51cda191 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1064,479 +1064,3 @@ async def chat( response_iterator = event_generator(q, image=image) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) - - -# @api_chat.post("") -@requires(["authenticated"]) -async def old_chat( - request: Request, - common: CommonQueryParams, - body: ChatRequestBody, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") - ), -): - # Access the parameters from the body - q = body.q - n = body.n - d = body.d - stream = body.stream - title = body.title - conversation_id = body.conversation_id - city = body.city - region = body.region - country = body.country or get_country_name_from_timezone(body.timezone) - country_code = body.country_code or get_country_code_from_timezone(body.timezone) - timezone = body.timezone - image = body.image - - async def event_generator(q: str, image: str): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - connection_alive = True - user: KhojUser = request.user.object - subscribed: bool = has_required_scope(request, ["premium"]) - event_delimiter = "␃🔚␗" - q = unquote(q) - nonlocal conversation_id - - uploaded_image_url = None - if image: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - try: - 
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) - except: - uploaded_image_url = None - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal connection_alive, ttft - if not connection_alive or await request.is_disconnected(): - connection_alive = False - logger.warning(f"User {user} disconnected from {common.client} client") - return - try: - if event_type == ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - if event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client: {e}") - return - except Exception as e: - connection_alive = False - logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) - return - finally: - yield event_delimiter - - async def send_llm_response(response: str): - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - chat_metadata["latency"] = f"{latency:.3f}" - chat_metadata["ttft_latency"] = f"{ttft:.3f}" - - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, - client_application=request.user.client_app, - conversation_id=conversation_id, - title=title, - create_new=body.create_new, - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found"): - yield result - return - conversation_id = conversation.id - - agent: Agent | None = None - default_agent = await AgentAdapters.aget_default_agent() - if conversation.agent and conversation.agent != default_agent: - agent = conversation.agent - - if not conversation.agent: - conversation.agent = default_agent - await conversation.asave() - agent = default_agent - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - location = None - if city or region or country or country_code: - location = LocationData(city=city, region=region, country=country, country_code=country_code) - - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == 
[ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources( - q, - meta_log, - is_automated_task, - subscribed=subscribed, - uploaded_image_url=uploaded_image_url, - agent=agent, - ) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event( - ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" - ): - yield result - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent) - async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): - yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - agent_has_entries = await EntryAdapters.aagent_has_entries(agent) - if len(file_filters) == 0 and not agent_has_entries: - response_log = "No files selected for summarization. Please add files using the section on the left." - async for result in send_llm_response(response_log): - yield result - elif len(file_filters) > 1 and not agent_has_entries: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result - else: - response_log = await generate_summary_from_files( - q=q, - user=user, - file_filters=file_filters, - meta_log=meta_log, - subscribed=subscribed, - send_status_func=partial(send_event, ChatEvent.STATUS), - send_response_func=partial(send_llm_response), - ) - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, - ) - return - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - async for result in send_llm_response(formatted_help): - yield result - return - # Adding specification to search online specifically on khoj.dev pages. 
- custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_llm_response(error_message): - yield result - return - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - uploaded_image_url=uploaded_image_url, - ) - async for result in send_llm_response(llm_response): - yield result - return - - # Gather Context - ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}"): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - custom_filters, - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except ValueError as e: - error_message = f"Error searching online: {e}. 
Attempting to respond without online results" - logger.warning(error_message) - async for result in send_llm_response(error_message): - yield result - return - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", - exc_info=True, - ) - - ## Send Gathered References - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": online_results, - }, - ): - yield result - - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - subscribed=subscribed, - send_status_func=partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - image, status_code, improved_image_prompt, intent_type = result - - if image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - uploaded_image_url=uploaded_image_url, - ) - content_obj = { - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): - yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation_id, - location, - user_name, - uploaded_image_url, - ) - - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - - continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: - if item is None: - async for result in 
send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return - if not connection_alive or not continue_stream: - continue - try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): - yield result - except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, image=image), media_type="text/plain") - ## Non-Streaming Text Response - else: - response_iterator = event_generator(q, image=image) - response_data = await read_chat_stream(response_iterator) - return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 3921cb78..dbbb9996 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -208,7 +208,7 @@ async def execute_information_collection( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] else: - direct_web_pages = result + direct_web_pages: Dict[str, Dict] = result webpages = [] for query in direct_web_pages: From 03544efde2826a000e129960b58e609135a57e6a Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 17:48:24 -0700 Subject: [PATCH 12/88] Ignore typing of the result dict for online, web page scrape --- src/khoj/routers/research.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index dbbb9996..837fb228 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -191,7 +191,7 @@ async def execute_information_collection( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] else: - online_results = result + online_results: Dict[str, Dict] = result # type: ignore this_iteration.onlineContext = online_results elif this_iteration.data_source == ConversationCommand.Webpage: @@ -208,7 +208,7 @@ async def execute_information_collection( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] else: - direct_web_pages: Dict[str, Dict] = result + direct_web_pages: Dict[str, Dict] = result # type: ignore webpages = [] for query in direct_web_pages: From 717d9da8d87532ec2b4f087e3cdd682021312c31 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 17:57:08 -0700 Subject: [PATCH 13/88] Handle when summarize result is not present, rename variable in for loop from query --- src/khoj/routers/api_chat.py | 17 +++++++++-------- src/khoj/routers/research.py | 10 +++++----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 51cda191..b5394d8d 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -717,16 +717,17 @@ async def chat( file_filters=conversation.file_filters if conversation else [], ): if type(research_result) == InformationCollectionIteration: - pending_research = False - # if research_result.onlineContext: - # researched_results += str(research_result.onlineContext) - # online_results.update(research_result.onlineContext) + if research_result.summarizedResult: + pending_research = False + # if research_result.onlineContext: + # researched_results += str(research_result.onlineContext) + # online_results.update(research_result.onlineContext) - # if research_result.context: - # 
researched_results += str(research_result.context) - # compiled_references.extend(research_result.context) + # if research_result.context: + # researched_results += str(research_result.context) + # compiled_references.extend(research_result.context) - researched_results += research_result.summarizedResult + researched_results += research_result.summarizedResult else: yield research_result diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 837fb228..60d00cd9 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -211,13 +211,13 @@ async def execute_information_collection( direct_web_pages: Dict[str, Dict] = result # type: ignore webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + for web_query in direct_web_pages: + if online_results.get(web_query): + online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} - for webpage in direct_web_pages[query]["webpages"]: + for webpage in direct_web_pages[web_query]["webpages"]: webpages.append(webpage["link"]) yield send_status_func(f"**Read web pages**: {webpages}") From 028b6e6379f66639d8bb31c1bcaa9396a24024ff Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 18:14:08 -0700 Subject: [PATCH 14/88] Fix yield for scraping direct web page --- src/khoj/routers/research.py | 51 ++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 60d00cd9..b3327f35 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -195,33 +195,34 @@ async def execute_information_collection( this_iteration.onlineContext = online_results elif this_iteration.data_source == ConversationCommand.Webpage: - async for result in read_webpages( - this_iteration.query, - conversation_history, - location, - user, - subscribed, - send_status_func, - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages: Dict[str, Dict] = result # type: ignore + try: + async for result in read_webpages( + this_iteration.query, + conversation_history, + location, + user, + subscribed, + send_status_func, + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages: Dict[str, Dict] = result # type: ignore - webpages = [] - for web_query in direct_web_pages: - if online_results.get(web_query): - online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] - else: - online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} + webpages = [] + for web_query in direct_web_pages: + if online_results.get(web_query): + online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] + else: + online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} - for webpage in direct_web_pages[web_query]["webpages"]: - webpages.append(webpage["link"]) - yield send_status_func(f"**Read web pages**: {webpages}") - - this_iteration.onlineContext = online_results + for webpage in direct_web_pages[web_query]["webpages"]: + 
webpages.append(webpage["link"]) + this_iteration.onlineContext = online_results + except Exception as e: + logger.error(f"Error reading webpages: {e}", exc_info=True) # TODO: Fix summarize later # elif this_iteration.data_source == ConversationCommand.Summarize: From a6905a9f0cea8d30662094408020c75467dce3b0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 19:00:54 -0700 Subject: [PATCH 15/88] Pass background context to iterating chat director --- src/khoj/processor/conversation/prompts.py | 8 ++++++-- src/khoj/routers/api_chat.py | 1 + src/khoj/routers/research.py | 15 ++++++++++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 6ef1cb94..ae4f8230 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -498,6 +498,11 @@ You are an extremely methodical planner. Your goal is to make a plan to execute If you already know the answer to the question, return an empty response, e.g., {{}}. +Background Context: +- Current Date: {day_of_week}, {current_date} +- User's Location: {location} +- {username} + Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources: {tools} @@ -523,8 +528,7 @@ previous_iteration = PromptTemplate.from_template( data_source: {data_source} query: {query} summary: {summary} ---- -""".strip() +---""" ) pick_relevant_information_collection_tools = PromptTemplate.from_template( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b5394d8d..505c0248 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -713,6 +713,7 @@ async def chat( uploaded_image_url=uploaded_image_url, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), + user_name=user_name, location=location, file_filters=conversation.file_filters if conversation else [], ): diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index b3327f35..028fffb7 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,5 +1,6 @@ import json import logging +from datetime import datetime from typing import Any, Callable, Dict, List, Optional from fastapi import Request @@ -48,6 +49,8 @@ async def apick_next_tool( conversation_history: dict, subscribed: bool, uploaded_image_url: str = None, + location: LocationData = None, + user_name: str = None, agent: Agent = None, previous_iterations: List[InformationCollectionIteration] = None, ): @@ -84,12 +87,21 @@ async def apick_next_tool( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" ) + # Extract Past User Message and Inferred Questions from Conversation Log + today = datetime.today() + location_data = f"{location}" if location else "Unknown" + username = prompts.user_name.format(name=user_name) if user_name else "" + # TODO Add current date/time to the query function_planning_prompt = prompts.plan_function_execution.format( query=query, tools=tool_options_str, chat_history=chat_history, personality_context=personality_context, + current_date=today.strftime("%Y-%m-%d"), + day_of_week=today.strftime("%A"), + username=username, + location=location_data, previous_iterations=previous_iterations_history, ) @@ -135,6 +147,7 @@ async def execute_information_collection( uploaded_image_url: str = None, agent: Agent = None, send_status_func: Optional[Callable] = 
None, + user_name: str = None, location: LocationData = None, file_filters: List[str] = [], ): @@ -150,7 +163,7 @@ async def execute_information_collection( result: str = "" this_iteration = await apick_next_tool( - query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations + query, conversation_history, subscribed, uploaded_image_url, location, user_name, agent, previous_iterations ) if this_iteration.data_source == ConversationCommand.Notes: ## Extract Document References From ec248efd3144663fbcd7e181ce140a517c41cd7d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 19:01:34 -0700 Subject: [PATCH 16/88] Allow iterative chat director to do notes search --- src/khoj/routers/api.py | 7 ++--- src/khoj/routers/api_chat.py | 60 ++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d26b7b5a..11ab1112 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -355,10 +355,9 @@ async def extract_references_and_questions( agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent) if ( - # not ConversationCommand.Notes in conversation_commands - # and not ConversationCommand.Default in conversation_commands - # and not agent_has_entries - True + not ConversationCommand.Notes in conversation_commands + and not ConversationCommand.Default in conversation_commands + and not agent_has_entries ): yield compiled_references, inferred_queries, q return diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 505c0248..89b56a29 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -853,37 +853,37 @@ async def chat( yield result return - # Gather Context - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] + # # Gather Context + # async for result in extract_references_and_questions( + # request, + # meta_log, + # q, + # (n or 7), + # d, + # conversation_id, + # conversation_commands, + # location, + # partial(send_event, ChatEvent.STATUS), + # uploaded_image_url=uploaded_image_url, + # agent=agent, + # ): + # if isinstance(result, dict) and ChatEvent.STATUS in result: + # yield result[ChatEvent.STATUS] + # else: + # compiled_references.extend(result[0]) + # inferred_queries.extend(result[1]) + # defiltered_query = result[2] - if not is_none_or_empty(compiled_references): - try: - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - except Exception as e: - # TODO Get correct type for compiled across research notes extraction - logger.error(f"Error extracting references: {e}", exc_info=True) + # if not is_none_or_empty(compiled_references): + # try: + # headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # # Strip only leading # from headings + # headings = 
headings.replace("#", "") + # async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): + # yield result + # except Exception as e: + # # TODO Get correct type for compiled across research notes extraction + # logger.error(f"Error extracting references: {e}", exc_info=True) if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): async for result in send_llm_response(f"{no_entries_found.format()}"): From a6f6e4f418c328f9d9cf74fb244a10ab40d2428d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 20:34:20 -0700 Subject: [PATCH 17/88] Fix notes references and passage of user query in the chat flow --- src/khoj/routers/api_chat.py | 14 ++++++-------- src/khoj/routers/helpers.py | 2 +- src/khoj/routers/research.py | 4 ++-- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 89b56a29..b2689a7f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -720,13 +720,11 @@ async def chat( if type(research_result) == InformationCollectionIteration: if research_result.summarizedResult: pending_research = False - # if research_result.onlineContext: - # researched_results += str(research_result.onlineContext) - # online_results.update(research_result.onlineContext) + if research_result.onlineContext: + online_results.update(research_result.onlineContext) - # if research_result.context: - # researched_results += str(research_result.context) - # compiled_references.extend(research_result.context) + if research_result.context: + compiled_references.extend(research_result.context) researched_results += research_result.summarizedResult @@ -1021,10 +1019,9 @@ async def chat( async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): yield result llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, + q, meta_log, conversation, - researched_results, compiled_references, online_results, inferred_queries, @@ -1035,6 +1032,7 @@ async def chat( location, user_name, uploaded_image_url, + researched_results, ) # Send Response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 279ad85e..e4ebdb51 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -947,7 +947,6 @@ def generate_chat_response( q: str, meta_log: dict, conversation: Conversation, - meta_research: str = "", compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, inferred_queries: List[str] = [], @@ -958,6 +957,7 @@ def generate_chat_response( location_data: LocationData = None, user_name: Optional[str] = None, uploaded_image_url: Optional[str] = None, + meta_research: str = "", ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 028fffb7..5960c88c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -33,7 +33,7 @@ class InformationCollectionIteration: self, data_source: str, query: str, - context: str = None, + context: Dict[str, Dict] = None, onlineContext: dict = None, summarizedResult: str = None, ): @@ -187,7 +187,7 @@ async def execute_information_collection( compiled_references.extend(result[0]) inferred_queries.extend(result[1]) defiltered_query = result[2] - this_iteration.context = str(compiled_references) + this_iteration.context = compiled_references elif 
this_iteration.data_source == ConversationCommand.Online: async for result in search_online( From 6ad85e22756f1830e68b80bc235ae86ba3198b28 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 21:19:04 -0700 Subject: [PATCH 18/88] Fix to continue showing retrieved documents in train of thought --- src/khoj/routers/research.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 5960c88c..fd52142b 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -21,6 +21,7 @@ from khoj.routers.helpers import ( from khoj.utils.helpers import ( ConversationCommand, function_calling_description_for_llm, + is_none_or_empty, timer, ) from khoj.utils.rawconfig import LocationData @@ -189,6 +190,17 @@ async def execute_information_collection( defiltered_query = result[2] this_iteration.context = compiled_references + if not is_none_or_empty(compiled_references): + try: + headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + # Strip only leading # from headings + headings = headings.replace("#", "") + async for result in send_status_func(f"**Found Relevant Notes**: {headings}"): + yield result + except Exception as e: + # TODO Get correct type for compiled across research notes extraction + logger.error(f"Error extracting references: {e}", exc_info=True) + elif this_iteration.data_source == ConversationCommand.Online: async for result in search_online( this_iteration.query, From 4d33239af67107da85dab8968db43f9aa6fa15d3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 21:23:18 -0700 Subject: [PATCH 19/88] Improve prompts for the iterative chat director --- src/khoj/processor/conversation/prompts.py | 36 ++++++++++++++-------- src/khoj/routers/research.py | 27 +++++++++++----- src/khoj/utils/helpers.py | 6 ++-- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index ae4f8230..0738af4e 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -487,16 +487,24 @@ Khoj: plan_function_execution = PromptTemplate.from_template( """ -You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query. +You are a smart, methodical researcher. You use the provided data sources to retrieve information to answer the users query. +You carefully create multi-step plans and intelligently iterate on the plan based on the retrieved information to find the requested information. {personality_context} -- You have access to a variety of data sources to help you answer the user's question -- You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained. -- You are given multiple iterations to with these data sources to answer the user's question -- You are provided with additional context. If you have enough context to answer the question, then exit execution -- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration. -- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." 
and then "Tell me the GDP of ." - -If you already know the answer to the question, return an empty response, e.g., {{}}. +- Use the data sources provided below, one at a time, if you need to find more information. Their output will be shown to you in the next iteration. +- You are allowed upto {max_iterations} iterations to use these data sources to answer the user's question +- If you have enough information to answer the question, then exit execution by returning an empty response. E.g., {{}} +- Ensure the query contains enough context to retrieve relevant information from the data sources. +- Break down the problem into smaller steps. Some examples are provided below assuming you have access to the notes and online data sources: + - If the user asks for the population of their hometown + 1. Try look up their hometown in their notes + 2. Only then try find the population of the city online. + - If the user asks for their computer's specs + 1. Try find the computer model in their notes + 2. Now look up their computer models spec online + - If the user asks what clothes to carry for their upcoming trip + 1. Find the itinerary of their upcoming trip in their notes + 2. Next find the weather forecast at the destination online + 3. Then find if they mention what clothes they own in their notes Background Context: - Current Date: {day_of_week}, {current_date} @@ -525,10 +533,12 @@ Response: previous_iteration = PromptTemplate.from_template( """ -data_source: {data_source} -query: {query} -summary: {summary} ----""" +# Iteration {index}: +# --- +- data_source: {data_source} +- query: {query} +- summary: {summary} +""" ) pick_relevant_information_collection_tools = PromptTemplate.from_template( diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index fd52142b..2974b511 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -54,6 +54,7 @@ async def apick_next_tool( user_name: str = None, agent: Agent = None, previous_iterations: List[InformationCollectionIteration] = None, + max_iterations: int = 5, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. 
@@ -72,11 +73,12 @@ async def apick_next_tool( chat_history = construct_chat_history(conversation_history) previous_iterations_history = "" - for iteration in previous_iterations: + for idx, iteration in enumerate(previous_iterations): iteration_data = prompts.previous_iteration.format( query=iteration.query, data_source=iteration.data_source, summary=iteration.summarizedResult, + index=idx + 1, ) previous_iterations_history += iteration_data @@ -104,6 +106,7 @@ async def apick_next_tool( username=username, location=location_data, previous_iterations=previous_iterations_history, + max_iterations=max_iterations, ) chat_model_option = await ConversationAdapters.aget_advanced_conversation_config() @@ -152,10 +155,10 @@ async def execute_information_collection( location: LocationData = None, file_filters: List[str] = [], ): - iteration = 0 + current_iteration = 0 MAX_ITERATIONS = 2 previous_iterations: List[InformationCollectionIteration] = [] - while iteration < MAX_ITERATIONS: + while current_iteration < MAX_ITERATIONS: online_results: Dict = dict() compiled_references: List[Any] = [] @@ -164,7 +167,15 @@ async def execute_information_collection( result: str = "" this_iteration = await apick_next_tool( - query, conversation_history, subscribed, uploaded_image_url, location, user_name, agent, previous_iterations + query, + conversation_history, + subscribed, + uploaded_image_url, + location, + user_name, + agent, + previous_iterations, + MAX_ITERATIONS, ) if this_iteration.data_source == ConversationCommand.Notes: ## Extract Document References @@ -291,9 +302,9 @@ async def execute_information_collection( # ) # ) else: - iteration = MAX_ITERATIONS + current_iteration = MAX_ITERATIONS - iteration += 1 + current_iteration += 1 if compiled_references or online_results: results_data = f"**Results**:\n" @@ -302,8 +313,8 @@ async def execute_information_collection( if online_results: results_data += f"**Online Results**: {online_results}\n" - intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) - this_iteration.summarizedResult = intermediate_result + # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) + this_iteration.summarizedResult = results_data previous_iterations.append(this_iteration) yield this_iteration diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 25243a27..9ed8ffa2 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -346,9 +346,9 @@ tool_descriptions_for_llm = { } function_calling_description_for_llm = { - ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.", - ConversationCommand.Online: "Use this if you think the there's important information on the internet related to the query.", - ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are share of a webpage URL that will help you directly answer this query", + ConversationCommand.Notes: "To search the user's personal knowledge base. 
Especially helpful if the question expects context from the user's notes or documents.", + ConversationCommand.Online: "To search for the latest, up-to-date information from the internet.", + ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.", } mode_descriptions_for_llm = { From 804473320103df69df24601cd40d6953860f00ee Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 13:37:06 -0700 Subject: [PATCH 20/88] Give Khoj ability to run python code as a tool triggered via chat API Create python code executing chat actor - The chat actor generate python code within sandbox constraints - Run the generated python code in the cohere terrarium, pyodide based sandbox accessible at sandbox url --- .../conversation/anthropic/anthropic_chat.py | 5 + .../conversation/google/gemini_chat.py | 5 + .../conversation/offline/chat_model.py | 7 +- src/khoj/processor/conversation/openai/gpt.py | 5 + src/khoj/processor/conversation/prompts.py | 42 +++++++ src/khoj/processor/conversation/utils.py | 2 + src/khoj/routers/api_chat.py | 28 ++++- src/khoj/routers/helpers.py | 103 ++++++++++++++++++ src/khoj/utils/helpers.py | 7 +- 9 files changed, 200 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index cb51abb4..c980c00d 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -126,6 +126,7 @@ def converse_anthropic( references, user_query, online_results: Optional[Dict[str, Dict]] = None, + code_results: Optional[Dict[str, Dict]] = None, conversation_log={}, model: Optional[str] = "claude-instant-1.2", api_key: Optional[str] = None, @@ -175,6 +176,10 @@ def converse_anthropic( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) + if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + conversation_primer = ( + f"{prompts.code_executed_context.format(code_results=str(code_results))}\n{conversation_primer}" + ) if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: conversation_primer = ( f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7359b3eb..5735799e 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -122,6 +122,7 @@ def converse_gemini( references, user_query, online_results: Optional[Dict[str, Dict]] = None, + code_results: Optional[Dict[str, Dict]] = None, conversation_log={}, model: Optional[str] = "gemini-1.5-flash", api_key: Optional[str] = None, @@ -173,6 +174,10 @@ def converse_gemini( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) + if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + conversation_primer = ( + f"{prompts.code_executed_context.format(code_results=str(code_results))}\n{conversation_primer}" + ) if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: conversation_primer = ( 
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 4eafae00..d9d99f21 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -135,7 +135,8 @@ def filter_questions(questions: List[str]): def converse_offline( user_query, references=[], - online_results=[], + online_results={}, + code_results={}, conversation_log={}, model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, @@ -187,6 +188,10 @@ def converse_offline( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) + if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results): + conversation_primer = ( + f"{prompts.code_executed_context.format(code_results=str(code_results))}\n{conversation_primer}" + ) if ConversationCommand.Online in conversation_commands: simplified_online_results = online_results.copy() for result in online_results: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index ad02b10e..a4850cfd 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -123,6 +123,7 @@ def converse( references, user_query, online_results: Optional[Dict[str, Dict]] = None, + code_results: Optional[Dict[str, Dict]] = None, conversation_log={}, model: str = "gpt-4o-mini", api_key: Optional[str] = None, @@ -176,6 +177,10 @@ def converse( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) + if not is_none_or_empty(code_results): + conversation_primer = ( + f"{prompts.code_executed_context.format(code_results=str(code_results))}\n{conversation_primer}" + ) if not is_none_or_empty(online_results): conversation_primer = ( f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 0738af4e..23788cab 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -730,6 +730,48 @@ Khoj: """.strip() ) +# Code Generation +# -- +python_code_generation_prompt = PromptTemplate.from_template( + """ +You are Khoj, an advanced python programmer. You are tasked with constructing **up to three** python programs to best answer the user query. +- The python program will run in a pyodide python sandbox with no network access. +- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query +- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4, sympy, brotli, cryptography, fast-parquet +- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead. +- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user. +- Use as much context from the previous questions and answers as required to generate your code. +{personality_context} +What code will you need to write, if any, to answer the user's question? +Provide code programs as a list of strings in a JSON object with key "codes". 
+Current Date: {current_date} +User's Location: {location} +{username} + +The JSON schema is of the form {{"codes": ["code1", "code2", "code3"]}} +For example: +{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}} + +Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else. +History: +{chat_history} + +User: {query} +Khoj: +""".strip() +) + +code_executed_context = PromptTemplate.from_template( + """ +Use the provided code executions to inform your response. +Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided code execution results or past conversations. + +Code Execution Results: +{code_results} +""".strip() +) + + # Automations # -- crontime_prompt = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e841c484..9a2ba230 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -104,6 +104,7 @@ def save_to_conversation_log( user_message_time: str = None, compiled_references: List[Dict[str, Any]] = [], online_results: Dict[str, Any] = {}, + code_results: Dict[str, Any] = {}, inferred_queries: List[str] = [], intent_type: str = "remember", client_application: ClientApplication = None, @@ -123,6 +124,7 @@ def save_to_conversation_log( "context": compiled_references, "intent": {"inferred-queries": inferred_queries, "type": intent_type}, "onlineContext": online_results, + "codeContext": code_results, "automationId": automation_id, }, conversation_log=meta_log.get("chat", []), diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index b2689a7f..cdf16bd9 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,7 +3,6 @@ import base64 import json import logging import time -import warnings from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -47,6 +46,7 @@ from khoj.routers.helpers import ( is_query_empty, is_ready_to_chat, read_chat_stream, + run_code, update_telemetry_state, validate_conversation_config, ) @@ -950,6 +950,30 @@ async def chat( exc_info=True, ) + ## Gather Code Results + if ConversationCommand.Code in conversation_commands: + try: + async for result in run_code( + defiltered_query, + meta_log, + location, + user, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + code_results = result + async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"): + yield result + except ValueError as e: + logger.warning( + f"Failed to use code tool: {e}. 
Attempting to respond without code results", + exc_info=True, + ) + ## Send Gathered References async for result in send_event( ChatEvent.REFERENCES, @@ -957,6 +981,7 @@ async def chat( "inferredQueries": inferred_queries, "context": compiled_references, "onlineContext": online_results, + "codeContext": code_results, }, ): yield result @@ -1024,6 +1049,7 @@ async def chat( conversation, compiled_references, online_results, + code_results, inferred_queries, conversation_commands, user, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e4ebdb51..f17fe1f5 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -24,6 +24,7 @@ from typing import ( ) from urllib.parse import parse_qs, quote, urljoin, urlparse +import aiohttp import cron_descriptor import pytz import requests @@ -519,6 +520,103 @@ async def generate_online_subqueries( return [q] +async def run_code( + query: str, + conversation_history: dict, + location_data: LocationData, + user: KhojUser, + send_status_func: Optional[Callable] = None, + uploaded_image_url: str = None, + agent: Agent = None, + sandbox_url: str = "http://localhost:8080", +): + # Generate Code + if send_status_func: + async for event in send_status_func(f"**Generate code snippets** for {query}"): + yield {ChatEvent.STATUS: event} + try: + with timer("Chat actor: Generate programs to execute", logger): + codes = await generate_python_code( + query, conversation_history, location_data, user, uploaded_image_url, agent + ) + except Exception as e: + raise ValueError(f"Failed to generate code for {query} with error: {e}") + + # Run Code + if send_status_func: + async for event in send_status_func(f"**Running {len(codes)} code snippets**"): + yield {ChatEvent.STATUS: event} + try: + tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes] + with timer("Chat actor: Execute generated programs", logger): + results = await asyncio.gather(*tasks) + for result in results: + code = result.pop("code") + logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--") + yield {query: {"code": code, "results": result}} + except Exception as e: + raise ValueError(f"Failed to run code for {query} with error: {e}") + + +async def generate_python_code( + q: str, + conversation_history: dict, + location_data: LocationData, + user: KhojUser, + uploaded_image_url: str = None, + agent: Agent = None, +) -> List[str]: + location = f"{location_data}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" + chat_history = construct_chat_history(conversation_history) + + utc_date = datetime.utcnow().strftime("%Y-%m-%d") + personality_context = ( + prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" + ) + + code_generation_prompt = prompts.python_code_generation_prompt.format( + current_date=utc_date, + query=q, + chat_history=chat_history, + location=location, + username=username, + personality_context=personality_context, + ) + + response = await send_message_to_model_wrapper( + code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user + ) + + # Validate that the response is a non-empty, JSON-serializable list + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) + codes = [code.strip() for code in response["codes"] if code.strip()] + + if not isinstance(codes, list) or not codes or 
len(codes) == 0: + raise ValueError + return codes + + +async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]: + """ + Takes code to run as a string and calls the terrarium API to execute it. + Returns the result of the code execution as a dictionary. + """ + headers = {"Content-Type": "application/json"} + data = {"code": code} + + async with aiohttp.ClientSession() as session: + async with session.post(sandbox_url, json=data, headers=headers) as response: + if response.status == 200: + result: dict[str, Any] = await response.json() + result["code"] = code + return result + else: + return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"} + + async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. @@ -949,6 +1047,7 @@ def generate_chat_response( conversation: Conversation, compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, + code_results: Dict[str, Dict] = {}, inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, @@ -976,6 +1075,7 @@ def generate_chat_response( meta_log=meta_log, compiled_references=compiled_references, online_results=online_results, + code_results=code_results, inferred_queries=inferred_queries, client_application=client_application, conversation_id=conversation_id, @@ -1017,6 +1117,7 @@ def generate_chat_response( query_to_run, image_url=uploaded_image_url, online_results=online_results, + code_results=code_results, conversation_log=meta_log, model=chat_model, api_key=api_key, @@ -1037,6 +1138,7 @@ def generate_chat_response( compiled_references, query_to_run, online_results, + code_results, meta_log, model=conversation_config.chat_model, api_key=api_key, @@ -1054,6 +1156,7 @@ def generate_chat_response( compiled_references, query_to_run, online_results, + code_results, meta_log, model=conversation_config.chat_model, api_key=api_key, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 9ed8ffa2..0e0193a9 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -309,6 +309,7 @@ class ConversationCommand(str, Enum): Help = "help" Online = "online" Webpage = "webpage" + Code = "code" Image = "image" Text = "text" Automation = "automation" @@ -322,6 +323,7 @@ command_descriptions = { ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Online: "Search for information on the internet.", ConversationCommand.Webpage: "Get information from webpage suggested by you.", + ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.", ConversationCommand.Image: "Generate images by describing your imagination in words.", ConversationCommand.Automation: "Automatically run your query at a specified time or interval.", ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation", @@ -342,6 +344,7 @@ tool_descriptions_for_llm = { ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. 
Note: **Questions about Khoj should always use this data source**",
     ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
+    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when you need to parse information, run complex calculations, or create documents and charts for the user. Matplotlib, bs4, pandas, numpy, etc. are available.",
     ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
 }
 
@@ -352,13 +355,13 @@ function_calling_description_for_llm = {
 }
 
 mode_descriptions_for_llm = {
-    ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
+    ConversationCommand.Image: "Use this if the user is requesting you to generate images based on their description. This does not support generating charts or graphs.",
     ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
     ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
 }
 
 mode_descriptions_for_agent = {
-    ConversationCommand.Image: "Agent can generate image in response.",
+    ConversationCommand.Image: "Agent can generate images in response. It cannot use this to generate charts and graphs.",
     ConversationCommand.Text: "Agent can generate text in response.",
 }

From a98f97ed5e97cc9f2ac2cec5179b5493b7755788 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Wed, 9 Oct 2024 15:54:54 -0700
Subject: [PATCH 21/88] Refactor Run Code tool into separate module and
 modularize code functions

Move construct_chat_history and ChatEvent enum into conversation.utils
and move send_message_to_model_wrapper to conversation.helper to
modularize code. And start thinning out the bloated routers.helper

- conversation.util components are shared functions that conversation
  child packages can use.
- conversation.helper components can't be imported by conversation packages but it can use these child packages This division allows better modularity while avoiding circular import dependencies --- src/khoj/processor/conversation/helpers.py | 126 +++++++++++ src/khoj/processor/conversation/utils.py | 23 +- src/khoj/processor/tools/run_code.py | 122 +++++++++++ src/khoj/routers/api_chat.py | 3 +- src/khoj/routers/helpers.py | 234 +-------------------- 5 files changed, 274 insertions(+), 234 deletions(-) create mode 100644 src/khoj/processor/conversation/helpers.py create mode 100644 src/khoj/processor/tools/run_code.py diff --git a/src/khoj/processor/conversation/helpers.py b/src/khoj/processor/conversation/helpers.py new file mode 100644 index 00000000..06a8557c --- /dev/null +++ b/src/khoj/processor/conversation/helpers.py @@ -0,0 +1,126 @@ +from fastapi import HTTPException + +from khoj.database.adapters import ConversationAdapters, ais_user_subscribed +from khoj.database.models import ChatModelOptions, KhojUser +from khoj.processor.conversation.anthropic.anthropic_chat import ( + anthropic_send_message_to_model, +) +from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model +from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline +from khoj.processor.conversation.openai.gpt import send_message_to_model +from khoj.processor.conversation.utils import generate_chatml_messages_with_context +from khoj.utils import state +from khoj.utils.config import OfflineChatProcessorModel + + +async def send_message_to_model_wrapper( + message: str, + system_message: str = "", + response_type: str = "text", + chat_model_option: ChatModelOptions = None, + subscribed: bool = False, + uploaded_image_url: str = None, +): + conversation_config: ChatModelOptions = ( + chat_model_option or await ConversationAdapters.aget_default_conversation_config() + ) + + vision_available = conversation_config.vision_enabled + if not vision_available and uploaded_image_url: + vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() + if vision_enabled_config: + conversation_config = vision_enabled_config + vision_available = True + + chat_model = conversation_config.chat_model + max_tokens = ( + conversation_config.subscribed_max_prompt_size + if subscribed and conversation_config.subscribed_max_prompt_size + else conversation_config.max_prompt_size + ) + tokenizer = conversation_config.tokenizer + model_type = conversation_config.model_type + vision_available = conversation_config.vision_enabled + + if model_type == ChatModelOptions.ModelType.OFFLINE: + if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + + loaded_model = state.offline_chat_processor_config.loaded_model + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + loaded_model=loaded_model, + tokenizer_name=tokenizer, + max_prompt_size=max_tokens, + vision_enabled=vision_available, + model_type=conversation_config.model_type, + ) + + return send_message_to_model_offline( + messages=truncated_messages, + loaded_model=loaded_model, + model=chat_model, + max_prompt_size=max_tokens, + streaming=False, + response_type=response_type, + ) + + elif model_type == ChatModelOptions.ModelType.OPENAI: + openai_chat_config = conversation_config.openai_config + 
api_key = openai_chat_config.api_key + api_base_url = openai_chat_config.api_base_url + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + max_prompt_size=max_tokens, + tokenizer_name=tokenizer, + vision_enabled=vision_available, + uploaded_image_url=uploaded_image_url, + model_type=conversation_config.model_type, + ) + + return send_message_to_model( + messages=truncated_messages, + api_key=api_key, + model=chat_model, + response_type=response_type, + api_base_url=api_base_url, + ) + elif model_type == ChatModelOptions.ModelType.ANTHROPIC: + api_key = conversation_config.openai_config.api_key + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + max_prompt_size=max_tokens, + tokenizer_name=tokenizer, + vision_enabled=vision_available, + uploaded_image_url=uploaded_image_url, + model_type=conversation_config.model_type, + ) + + return anthropic_send_message_to_model( + messages=truncated_messages, + api_key=api_key, + model=chat_model, + ) + elif model_type == ChatModelOptions.ModelType.GOOGLE: + api_key = conversation_config.openai_config.api_key + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + max_prompt_size=max_tokens, + tokenizer_name=tokenizer, + vision_enabled=vision_available, + uploaded_image_url=uploaded_image_url, + ) + + return gemini_send_message_to_model( + messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type + ) + else: + raise HTTPException(status_code=500, detail="Invalid conversation config") diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 9a2ba230..00ef56d9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -2,6 +2,7 @@ import logging import math import queue from datetime import datetime +from enum import Enum from time import perf_counter from typing import Any, Dict, List, Optional @@ -10,7 +11,7 @@ from langchain.schema import ChatMessage from llama_cpp.llama import Llama from transformers import AutoTokenizer -from khoj.database.adapters import ConversationAdapters +from khoj.database.adapters import ConversationAdapters, ais_user_subscribed from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.utils import state @@ -75,6 +76,26 @@ class ThreadedGenerator: self.queue.put(StopIteration) +def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: + chat_history = "" + for chat in conversation_history.get("chat", [])[-n:]: + if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: {chat['message']}\n" + elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: [generated image redacted for space]\n" + return chat_history + + +class ChatEvent(Enum): + START_LLM_RESPONSE = "start_llm_response" + END_LLM_RESPONSE = "end_llm_response" + MESSAGE = "message" + REFERENCES = "references" + STATUS = "status" + + def message_to_log( user_message, chat_response, 
user_message_metadata={}, khoj_message_metadata={}, conversation_log=[] ): diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py new file mode 100644 index 00000000..2fd2e8d2 --- /dev/null +++ b/src/khoj/processor/tools/run_code.py @@ -0,0 +1,122 @@ +import asyncio +import datetime +import json +import logging +from typing import Any, Callable, List, Optional + +import aiohttp + +from khoj.database.adapters import ais_user_subscribed +from khoj.database.models import Agent, KhojUser +from khoj.processor.conversation import prompts +from khoj.processor.conversation.helpers import send_message_to_model_wrapper +from khoj.processor.conversation.utils import ( + ChatEvent, + construct_chat_history, + remove_json_codeblock, +) +from khoj.utils.helpers import timer +from khoj.utils.rawconfig import LocationData + +logger = logging.getLogger(__name__) + + +async def run_code( + query: str, + conversation_history: dict, + location_data: LocationData, + user: KhojUser, + send_status_func: Optional[Callable] = None, + uploaded_image_url: str = None, + agent: Agent = None, + sandbox_url: str = "http://localhost:8080", +): + # Generate Code + if send_status_func: + async for event in send_status_func(f"**Generate code snippets** for {query}"): + yield {ChatEvent.STATUS: event} + try: + with timer("Chat actor: Generate programs to execute", logger): + codes = await generate_python_code( + query, conversation_history, location_data, user, uploaded_image_url, agent + ) + except Exception as e: + raise ValueError(f"Failed to generate code for {query} with error: {e}") + + # Run Code + if send_status_func: + async for event in send_status_func(f"**Running {len(codes)} code snippets**"): + yield {ChatEvent.STATUS: event} + try: + tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes] + with timer("Chat actor: Execute generated programs", logger): + results = await asyncio.gather(*tasks) + for result in results: + code = result.pop("code") + logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--") + yield {query: {"code": code, "results": result}} + except Exception as e: + raise ValueError(f"Failed to run code for {query} with error: {e}") + + +async def generate_python_code( + q: str, + conversation_history: dict, + location_data: LocationData, + user: KhojUser, + uploaded_image_url: str = None, + agent: Agent = None, +) -> List[str]: + location = f"{location_data}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" + subscribed = await ais_user_subscribed(user) + chat_history = construct_chat_history(conversation_history) + + utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") + personality_context = ( + prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" + ) + + code_generation_prompt = prompts.python_code_generation_prompt.format( + current_date=utc_date, + query=q, + chat_history=chat_history, + location=location, + username=username, + personality_context=personality_context, + ) + + response = await send_message_to_model_wrapper( + code_generation_prompt, + uploaded_image_url=uploaded_image_url, + response_type="json_object", + subscribed=subscribed, + ) + + # Validate that the response is a non-empty, JSON-serializable list + response = response.strip() + response = remove_json_codeblock(response) + response = json.loads(response) + codes = [code.strip() for code 
in response["codes"] if code.strip()] + + if not isinstance(codes, list) or not codes or len(codes) == 0: + raise ValueError + return codes + + +async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]: + """ + Takes code to run as a string and calls the terrarium API to execute it. + Returns the result of the code execution as a dictionary. + """ + headers = {"Content-Type": "application/json"} + data = {"code": code} + + async with aiohttp.ClientSession() as session: + async with session.post(sandbox_url, json=data, headers=headers) as response: + if response.status == 200: + result: dict[str, Any] = await response.json() + result["code"] = code + return result + else: + return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"} diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index cdf16bd9..82d6f351 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -19,7 +19,6 @@ from khoj.database.adapters import ( AgentAdapters, ConversationAdapters, EntryAdapters, - FileObjectAdapters, PublicConversationAdapters, aget_user_name, ) @@ -29,6 +28,7 @@ from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -46,7 +46,6 @@ from khoj.routers.helpers import ( is_query_empty, is_ready_to_chat, read_chat_stream, - run_code, update_telemetry_state, validate_conversation_config, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f17fe1f5..3b97e694 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -24,7 +24,6 @@ from typing import ( ) from urllib.parse import parse_qs, quote, urljoin, urlparse -import aiohttp import cron_descriptor import pytz import requests @@ -79,13 +78,16 @@ from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, gemini_send_message_to_model, ) +from khoj.processor.conversation.helpers import send_message_to_model_wrapper from khoj.processor.conversation.offline.chat_model import ( converse_offline, send_message_to_model_offline, ) from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.utils import ( + ChatEvent, ThreadedGenerator, + construct_chat_history, generate_chatml_messages_with_context, remove_json_codeblock, save_to_conversation_log, @@ -208,18 +210,6 @@ def get_next_url(request: Request) -> str: return urljoin(str(request.base_url).rstrip("/"), next_path) -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['message']}\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - return chat_history - - def get_conversation_command(query: str, any_references: 
bool = False) -> ConversationCommand: if query.startswith("/notes"): return ConversationCommand.Notes @@ -520,103 +510,6 @@ async def generate_online_subqueries( return [q] -async def run_code( - query: str, - conversation_history: dict, - location_data: LocationData, - user: KhojUser, - send_status_func: Optional[Callable] = None, - uploaded_image_url: str = None, - agent: Agent = None, - sandbox_url: str = "http://localhost:8080", -): - # Generate Code - if send_status_func: - async for event in send_status_func(f"**Generate code snippets** for {query}"): - yield {ChatEvent.STATUS: event} - try: - with timer("Chat actor: Generate programs to execute", logger): - codes = await generate_python_code( - query, conversation_history, location_data, user, uploaded_image_url, agent - ) - except Exception as e: - raise ValueError(f"Failed to generate code for {query} with error: {e}") - - # Run Code - if send_status_func: - async for event in send_status_func(f"**Running {len(codes)} code snippets**"): - yield {ChatEvent.STATUS: event} - try: - tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes] - with timer("Chat actor: Execute generated programs", logger): - results = await asyncio.gather(*tasks) - for result in results: - code = result.pop("code") - logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--") - yield {query: {"code": code, "results": result}} - except Exception as e: - raise ValueError(f"Failed to run code for {query} with error: {e}") - - -async def generate_python_code( - q: str, - conversation_history: dict, - location_data: LocationData, - user: KhojUser, - uploaded_image_url: str = None, - agent: Agent = None, -) -> List[str]: - location = f"{location_data}" if location_data else "Unknown" - username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) - - utc_date = datetime.utcnow().strftime("%Y-%m-%d") - personality_context = ( - prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" - ) - - code_generation_prompt = prompts.python_code_generation_prompt.format( - current_date=utc_date, - query=q, - chat_history=chat_history, - location=location, - username=username, - personality_context=personality_context, - ) - - response = await send_message_to_model_wrapper( - code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user - ) - - # Validate that the response is a non-empty, JSON-serializable list - response = response.strip() - response = remove_json_codeblock(response) - response = json.loads(response) - codes = [code.strip() for code in response["codes"] if code.strip()] - - if not isinstance(codes, list) or not codes or len(codes) == 0: - raise ValueError - return codes - - -async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]: - """ - Takes code to run as a string and calls the terrarium API to execute it. - Returns the result of the code execution as a dictionary. 
- """ - headers = {"Content-Type": "application/json"} - data = {"code": code} - - async with aiohttp.ClientSession() as session: - async with session.post(sandbox_url, json=data, headers=headers) as response: - if response.status == 200: - result: dict[str, Any] = await response.json() - result["code"] = code - return result - else: - return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"} - - async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. @@ -837,119 +730,6 @@ async def generate_better_image_prompt( return response -async def send_message_to_model_wrapper( - message: str, - system_message: str = "", - response_type: str = "text", - chat_model_option: ChatModelOptions = None, - subscribed: bool = False, - uploaded_image_url: str = None, -): - conversation_config: ChatModelOptions = ( - chat_model_option or await ConversationAdapters.aget_default_conversation_config() - ) - - vision_available = conversation_config.vision_enabled - if not vision_available and uploaded_image_url: - vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() - if vision_enabled_config: - conversation_config = vision_enabled_config - vision_available = True - - chat_model = conversation_config.chat_model - max_tokens = ( - conversation_config.subscribed_max_prompt_size - if subscribed and conversation_config.subscribed_max_prompt_size - else conversation_config.max_prompt_size - ) - tokenizer = conversation_config.tokenizer - model_type = conversation_config.model_type - vision_available = conversation_config.vision_enabled - - if model_type == ChatModelOptions.ModelType.OFFLINE: - if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) - - loaded_model = state.offline_chat_processor_config.loaded_model - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model, - loaded_model=loaded_model, - tokenizer_name=tokenizer, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - model_type=conversation_config.model_type, - ) - - return send_message_to_model_offline( - messages=truncated_messages, - loaded_model=loaded_model, - model=chat_model, - max_prompt_size=max_tokens, - streaming=False, - response_type=response_type, - ) - - elif model_type == ChatModelOptions.ModelType.OPENAI: - openai_chat_config = conversation_config.openai_config - api_key = openai_chat_config.api_key - api_base_url = openai_chat_config.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model, - max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, - model_type=conversation_config.model_type, - ) - - return send_message_to_model( - messages=truncated_messages, - api_key=api_key, - model=chat_model, - response_type=response_type, - api_base_url=api_base_url, - ) - elif model_type == ChatModelOptions.ModelType.ANTHROPIC: - api_key = conversation_config.openai_config.api_key - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model, - 
max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, - model_type=conversation_config.model_type, - ) - - return anthropic_send_message_to_model( - messages=truncated_messages, - api_key=api_key, - model=chat_model, - ) - elif model_type == ChatModelOptions.ModelType.GOOGLE: - api_key = conversation_config.openai_config.api_key - truncated_messages = generate_chatml_messages_with_context( - user_message=message, - system_message=system_message, - model_name=chat_model, - max_prompt_size=max_tokens, - tokenizer_name=tokenizer, - vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, - ) - - return gemini_send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type - ) - else: - raise HTTPException(status_code=500, detail="Invalid conversation config") - - def send_message_to_model_wrapper_sync( message: str, system_message: str = "", @@ -1540,14 +1320,6 @@ Manage your automations [here](/automations). """.strip() -class ChatEvent(Enum): - START_LLM_RESPONSE = "start_llm_response" - END_LLM_RESPONSE = "end_llm_response" - MESSAGE = "message" - REFERENCES = "references" - STATUS = "status" - - class MessageProcessor: def __init__(self): self.references = {} From b373073f47f8d8dc3d6f96ba5520ed5616632e80 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 17:31:50 -0700 Subject: [PATCH 22/88] Show executed code in web app chat message references --- src/interface/web/app/chat/page.tsx | 12 +- src/interface/web/app/common/chatFunctions.ts | 20 +++- .../components/chatHistory/chatHistory.tsx | 3 + .../components/chatMessage/chatMessage.tsx | 24 ++++ .../referencePanel/referencePanel.tsx | 108 +++++++++++++++++- src/interface/web/app/factchecker/page.tsx | 4 + src/interface/web/app/share/chat/page.tsx | 1 + 7 files changed, 159 insertions(+), 13 deletions(-) diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index 7d87fd81..156658e7 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -12,7 +12,12 @@ import { processMessageChunk } from "../common/chatFunctions"; import "katex/dist/katex.min.css"; -import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage"; +import { + CodeContext, + Context, + OnlineContext, + StreamMessage, +} from "../components/chatMessage/chatMessage"; import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils"; import ChatInputArea, { ChatOptions } from "../components/chatInputArea/chatInputArea"; import { useAuthenticatedData } from "../common/auth"; @@ -167,6 +172,7 @@ export default function Chat() { trainOfThought: [], context: [], onlineContext: {}, + codeContext: {}, completed: false, timestamp: new Date().toISOString(), rawQuery: queryToProcess || "", @@ -195,6 +201,7 @@ export default function Chat() { // Track context used for chat response let context: Context[] = []; let onlineContext: OnlineContext = {}; + let codeContext: CodeContext = {}; while (true) { const { done, value } = await reader.read(); @@ -221,11 +228,12 @@ export default function Chat() { } // Track context used for chat response. 
References are rendered at the end of the chat - ({ context, onlineContext } = processMessageChunk( + ({ context, onlineContext, codeContext } = processMessageChunk( event, currentMessage, context, onlineContext, + codeContext, )); setMessages([...messages]); diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index 6d7c9bc1..b1bc60ed 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -1,13 +1,20 @@ -import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage"; +import { + CodeContext, + Context, + OnlineContext, + StreamMessage, +} from "../components/chatMessage/chatMessage"; export interface RawReferenceData { context?: Context[]; onlineContext?: OnlineContext; + codeContext?: CodeContext; } export interface ResponseWithReferences { context?: Context[]; online?: OnlineContext; + codeContext?: CodeContext; response?: string; } @@ -63,10 +70,11 @@ export function processMessageChunk( currentMessage: StreamMessage, context: Context[] = [], onlineContext: OnlineContext = {}, -): { context: Context[]; onlineContext: OnlineContext } { + codeContext: CodeContext = {}, +): { context: Context[]; onlineContext: OnlineContext; codeContext: CodeContext } { const chunk = convertMessageChunkToJson(rawChunk); - if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext }; + if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext }; if (chunk.type === "status") { console.log(`status: ${chunk.data}`); @@ -77,7 +85,8 @@ export function processMessageChunk( if (references.context) context = references.context; if (references.onlineContext) onlineContext = references.onlineContext; - return { context, onlineContext }; + if (references.codeContext) codeContext = references.codeContext; + return { context, onlineContext, codeContext }; } else if (chunk.type === "message") { const chunkData = chunk.data; if (chunkData !== null && typeof chunkData === "object") { @@ -102,13 +111,14 @@ export function processMessageChunk( console.log(`Completed streaming: ${new Date()}`); // Append any references after all the data has been streamed + if (codeContext) currentMessage.codeContext = codeContext; if (onlineContext) currentMessage.onlineContext = onlineContext; if (context) currentMessage.context = context; // Mark current message streaming as completed currentMessage.completed = true; } - return { context, onlineContext }; + return { context, onlineContext, codeContext }; } export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithReferences { diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index 1a7c90c0..177ff56d 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -295,6 +295,7 @@ export default function ChatHistory(props: ChatHistoryProps) { message: message.rawQuery, context: [], onlineContext: {}, + codeContext: {}, created: message.timestamp, by: "you", automationId: "", @@ -318,6 +319,7 @@ export default function ChatHistory(props: ChatHistoryProps) { message: message.rawResponse, context: message.context, onlineContext: message.onlineContext, + codeContext: message.codeContext, created: message.timestamp, by: "khoj", automationId: "", @@ -338,6 +340,7 @@ export default function ChatHistory(props: 
ChatHistoryProps) { message: props.pendingMessage, context: [], onlineContext: {}, + codeContext: {}, created: new Date().getTime().toString(), by: "you", automationId: "", diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 23371512..1836bfde 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -97,6 +97,26 @@ export interface OnlineContextData { peopleAlsoAsk: PeopleAlsoAsk[]; } +export interface CodeContext { + [key: string]: CodeContextData; +} + +export interface CodeContextData { + code: string; + results: { + success: boolean; + output_files: CodeContextFile[]; + std_out: string; + std_err: string; + code_runtime: number; + }; +} + +export interface CodeContextFile { + filename: string; + b64_data: string; +} + interface Intent { type: string; query: string; @@ -111,6 +131,7 @@ export interface SingleChatMessage { created: string; context: Context[]; onlineContext: OnlineContext; + codeContext: CodeContext; rawQuery?: string; intent?: Intent; agent?: AgentData; @@ -122,6 +143,7 @@ export interface StreamMessage { trainOfThought: string[]; context: Context[]; onlineContext: OnlineContext; + codeContext: CodeContext; completed: boolean; rawQuery: string; timestamp: string; @@ -539,6 +561,7 @@ const ChatMessage = forwardRef((props, ref) => const allReferences = constructAllReferences( props.chatMessage.context, props.chatMessage.onlineContext, + props.chatMessage.codeContext, ); return ( @@ -560,6 +583,7 @@ const ChatMessage = forwardRef((props, ref) => isMobileWidth={props.isMobileWidth} notesReferenceCardData={allReferences.notesReferenceCardData} onlineReferenceCardData={allReferences.onlineReferenceCardData} + codeReferenceCardData={allReferences.codeReferenceCardData} />
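To make the new interfaces concrete, here is a hedged sketch of the codeContext payload they are meant to model, written as the Python dict the backend would stream. Field names follow the CodeContextData and CodeContextFile interfaces above and the run_code output from the previous commit; the query, values, and runtime units are invented for illustration.

```python
# Hypothetical example of a code context payload, keyed by the generated query as in run_code().
code_context = {
    "Plot the population trend of Pune": {
        "code": "import matplotlib.pyplot as plt\n# ... save plot to population.png ...",
        "results": {
            "success": True,
            "std_out": "Saved plot to population.png",
            "std_err": "",
            "code_runtime": 412,  # assumed to be milliseconds; units are not specified in the diff
            "output_files": [
                {"filename": "population.png", "b64_data": "iVBORw0KGgo..."},
            ],
        },
    },
}
```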
diff --git a/src/interface/web/app/components/referencePanel/referencePanel.tsx b/src/interface/web/app/components/referencePanel/referencePanel.tsx index 8b808505..899f0b5b 100644 --- a/src/interface/web/app/components/referencePanel/referencePanel.tsx +++ b/src/interface/web/app/components/referencePanel/referencePanel.tsx @@ -11,7 +11,7 @@ const md = new markdownIt({ typographer: true, }); -import { Context, WebPage, OnlineContext } from "../chatMessage/chatMessage"; +import { Context, WebPage, OnlineContext, CodeContext } from "../chatMessage/chatMessage"; import { Card } from "@/components/ui/card"; import { @@ -94,11 +94,67 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) { ); } +interface CodeContextReferenceCardProps { + code: string; + output: string; + error: string; + showFullContent: boolean; +} + +function CodeContextReferenceCard(props: CodeContextReferenceCardProps) { + const fileIcon = getIconFromFilename(".py", "w-6 h-6 text-muted-foreground inline-flex mr-2"); + const snippet = DOMPurify.sanitize(props.code); + const [isHovering, setIsHovering] = useState(false); + + return ( + <> + + + setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg p-2 bg-muted border-none`} + > +

+ {fileIcon} + Code +

+

+ {snippet} +

+
+
+ + +

+ {fileIcon} + Code +

+

{snippet}

+
+
+
+ + ); +} + export interface ReferencePanelData { notesReferenceCardData: NotesContextReferenceData[]; onlineReferenceCardData: OnlineReferenceData[]; } +export interface CodeReferenceData { + code: string; + output: string; + error: string; +} + interface OnlineReferenceData { title: string; description: string; @@ -214,9 +270,27 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) { ); } -export function constructAllReferences(contextData: Context[], onlineData: OnlineContext) { +export function constructAllReferences( + contextData: Context[], + onlineData: OnlineContext, + codeContext: CodeContext, +) { const onlineReferences: OnlineReferenceData[] = []; const contextReferences: NotesContextReferenceData[] = []; + const codeReferences: CodeReferenceData[] = []; + + if (codeContext) { + for (const [key, value] of Object.entries(codeContext)) { + if (!value.results) { + continue; + } + codeReferences.push({ + code: value.code, + output: value.results.std_out, + error: value.results.std_err, + }); + } + } if (onlineData) { let localOnlineReferences = []; @@ -298,12 +372,14 @@ export function constructAllReferences(contextData: Context[], onlineData: Onlin return { notesReferenceCardData: contextReferences, onlineReferenceCardData: onlineReferences, + codeReferenceCardData: codeReferences, }; } export interface TeaserReferenceSectionProps { notesReferenceCardData: NotesContextReferenceData[]; onlineReferenceCardData: OnlineReferenceData[]; + codeReferenceCardData: CodeReferenceData[]; isMobileWidth: boolean; } @@ -315,16 +391,27 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) { }, [props.isMobileWidth]); const notesDataToShow = props.notesReferenceCardData.slice(0, numTeaserSlots); + const codeDataToShow = props.codeReferenceCardData.slice( + 0, + numTeaserSlots - notesDataToShow.length, + ); const onlineDataToShow = - notesDataToShow.length < numTeaserSlots - ? props.onlineReferenceCardData.slice(0, numTeaserSlots - notesDataToShow.length) + notesDataToShow.length + codeDataToShow.length < numTeaserSlots + ? 
props.onlineReferenceCardData.slice( + 0, + numTeaserSlots - codeDataToShow.length - notesDataToShow.length, + ) : []; const shouldShowShowMoreButton = - props.notesReferenceCardData.length > 0 || props.onlineReferenceCardData.length > 0; + props.notesReferenceCardData.length > 0 || + props.codeReferenceCardData.length > 0 || + props.onlineReferenceCardData.length > 0; const numReferences = - props.notesReferenceCardData.length + props.onlineReferenceCardData.length; + props.notesReferenceCardData.length + + props.codeReferenceCardData.length + + props.onlineReferenceCardData.length; if (numReferences === 0) { return null; @@ -346,6 +433,15 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) { /> ); })} + {codeDataToShow.map((code, index) => { + return ( + + ); + })} {onlineDataToShow.map((online, index) => { return ( @@ -622,6 +625,7 @@ export default function FactChecker() { context: [], created: new Date().toISOString(), onlineContext: {}, + codeContext: {}, }} isMobileWidth={isMobileWidth} /> diff --git a/src/interface/web/app/share/chat/page.tsx b/src/interface/web/app/share/chat/page.tsx index 9bc5f12d..9b13aafa 100644 --- a/src/interface/web/app/share/chat/page.tsx +++ b/src/interface/web/app/share/chat/page.tsx @@ -164,6 +164,7 @@ export default function SharedChat() { trainOfThought: [], context: [], onlineContext: {}, + codeContext: {}, completed: false, timestamp: new Date().toISOString(), rawQuery: queryToProcess || "", From 8d33c764b76209299e9498083a612c97f6bb145c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 9 Oct 2024 22:03:05 -0700 Subject: [PATCH 23/88] Allow iterative chat director to use python interpreter as a tool --- src/khoj/routers/api_chat.py | 4 +++- src/khoj/routers/research.py | 33 +++++++++++++++++++++++++++++++-- src/khoj/utils/helpers.py | 1 + 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 82d6f351..a87e4f24 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -696,6 +696,7 @@ async def chat( pending_research = True researched_results = "" online_results: Dict = dict() + code_results: Dict = dict() ## Extract Document References compiled_references: List[Any] = [] inferred_queries: List[Any] = [] @@ -721,7 +722,8 @@ async def chat( pending_research = False if research_result.onlineContext: online_results.update(research_result.onlineContext) - + if research_result.codeContext: + code_results.update(research_result.codeContext) if research_result.context: compiled_references.extend(research_result.context) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 2974b511..9a5e3169 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -10,6 +10,7 @@ from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import remove_json_codeblock from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ChatEvent, @@ -36,12 +37,14 @@ class InformationCollectionIteration: query: str, context: Dict[str, Dict] = None, onlineContext: dict = None, + codeContext: dict = None, summarizedResult: str = None, ): self.data_source = data_source self.query = query self.context = context self.onlineContext = onlineContext + self.codeContext = 
codeContext
         self.summarizedResult = summarizedResult
 
 
@@ -160,7 +163,7 @@ async def execute_information_collection(
     previous_iterations: List[InformationCollectionIteration] = []
     while current_iteration < MAX_ITERATIONS:
         online_results: Dict = dict()
-
+        code_results: Dict = dict()
         compiled_references: List[Any] = []
         inferred_queries: List[Any] = []
 
@@ -260,6 +263,30 @@
             except Exception as e:
                 logger.error(f"Error reading webpages: {e}", exc_info=True)
 
+        elif this_iteration.data_source == ConversationCommand.Code:
+            try:
+                async for result in run_code(
+                    this_iteration.query,
+                    conversation_history,
+                    location,
+                    user,
+                    send_status_func,
+                    uploaded_image_url=uploaded_image_url,
+                    agent=agent,
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
+                    else:
+                        code_results: Dict[str, Dict] = result  # type: ignore
+                        this_iteration.codeContext = code_results
+                async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
+                    yield result
+            except ValueError as e:
+                logger.warning(
+                    f"Failed to use code tool: {e}. Attempting to respond without code results",
+                    exc_info=True,
+                )
+
         # TODO: Fix summarize later
         # elif this_iteration.data_source == ConversationCommand.Summarize:
         #     response_log = ""
@@ -306,12 +333,14 @@
 
         current_iteration += 1
 
-        if compiled_references or online_results:
+        if compiled_references or online_results or code_results:
             results_data = f"**Results**:\n"
             if compiled_references:
                 results_data += f"**Document References**: {compiled_references}\n"
             if online_results:
                 results_data += f"**Online Results**: {online_results}\n"
+            if code_results:
+                results_data += f"**Code Results**: {code_results}\n"
 
             # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
             this_iteration.summarizedResult = results_data
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 0e0193a9..d3978fa4 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -352,6 +352,7 @@ function_calling_description_for_llm = {
     ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
     ConversationCommand.Online: "To search for the latest, up-to-date information from the internet.",
     ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
+    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when you need to parse information, run complex calculations, or create documents and charts for the user. Matplotlib, bs4, pandas, numpy, etc.
are available.", } mode_descriptions_for_llm = { From 536422a40c8f75d5d49c4222ed8cb465ce4135d9 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 23:54:11 -0700 Subject: [PATCH 24/88] Include code snippets in the reference panel --- .../app/components/referencePanel/referencePanel.tsx | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/interface/web/app/components/referencePanel/referencePanel.tsx b/src/interface/web/app/components/referencePanel/referencePanel.tsx index 899f0b5b..785d05d8 100644 --- a/src/interface/web/app/components/referencePanel/referencePanel.tsx +++ b/src/interface/web/app/components/referencePanel/referencePanel.tsx @@ -455,6 +455,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) { )}
@@ -465,6 +466,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) { interface ReferencePanelDataProps { notesReferenceCardData: NotesContextReferenceData[]; onlineReferenceCardData: OnlineReferenceData[]; + codeReferenceCardData: CodeReferenceData[]; } export default function ReferencePanel(props: ReferencePanelDataProps) { @@ -502,6 +504,15 @@ export default function ReferencePanel(props: ReferencePanelDataProps) { /> ); })} + {props.codeReferenceCardData.map((code, index) => { + return ( + + ); + })} From e69a8382f282bf104936cb77dc5975f80720721f Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 9 Oct 2024 23:56:57 -0700 Subject: [PATCH 25/88] Add a code icon for code-related train of thought --- src/interface/web/app/components/chatMessage/chatMessage.tsx | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 1836bfde..b75c852e 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -26,6 +26,7 @@ import { Palette, ClipboardText, Check, + Code, } from "@phosphor-icons/react"; import DOMPurify from "dompurify"; @@ -278,6 +279,10 @@ function chooseIconFromHeader(header: string, iconColor: string) { return ; } + if (compareHeader.includes("code")) { + return ; + } + return ; } From 2dc5804571bdf27cab76797939f892463b40b31c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 00:27:27 -0700 Subject: [PATCH 26/88] Extract defilter query into conversation utils for reuse --- src/khoj/processor/conversation/utils.py | 11 +++++++++++ src/khoj/routers/api.py | 5 ++--- src/khoj/routers/api_chat.py | 4 ++-- src/khoj/search_filter/date_filter.py | 2 -- src/khoj/search_filter/file_filter.py | 3 +-- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 00ef56d9..339024ae 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -14,6 +14,9 @@ from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters, ais_user_subscribed from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens +from khoj.search_filter.date_filter import DateFilter +from khoj.search_filter.file_filter import FileFilter +from khoj.search_filter.word_filter import WordFilter from khoj.utils import state from khoj.utils.helpers import is_none_or_empty, merge_dicts @@ -320,3 +323,11 @@ def reciprocal_conversation_to_chatml(message_pair): def remove_json_codeblock(response: str): """Remove any markdown json codeblock formatting if present. 
Useful for non schema enforceable models""" return response.removeprefix("```json").removesuffix("```") + + +def defilter_query(query: str): + """Remove any query filters in query""" + defiltered_query = query + for filter in [DateFilter(), WordFilter(), FileFilter()]: + defiltered_query = filter.defilter(defiltered_query) + return defiltered_query diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 11ab1112..46fdfa43 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -42,6 +42,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio +from khoj.processor.conversation.utils import defilter_query from khoj.routers.helpers import ( ApiUserRateLimiter, ChatEvent, @@ -375,9 +376,7 @@ async def extract_references_and_questions( return # Extract filter terms from user message - defiltered_query = q - for filter in [DateFilter(), WordFilter(), FileFilter()]: - defiltered_query = filter.defilter(defiltered_query) + defiltered_query = defilter_query(q) filters_in_query = q.replace(defiltered_query, "").strip() conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a87e4f24..60d414de 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -24,7 +24,7 @@ from khoj.database.adapters import ( ) from khoj.database.models import Agent, KhojUser from khoj.processor.conversation.prompts import help_message, no_entries_found -from khoj.processor.conversation.utils import save_to_conversation_log +from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online @@ -700,7 +700,7 @@ async def chat( ## Extract Document References compiled_references: List[Any] = [] inferred_queries: List[Any] = [] - defiltered_query: str = None + defiltered_query = defilter_query(q) if conversation_commands == [ConversationCommand.Default] or is_automated_task: async for research_result in execute_information_collection( diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index 62643e15..91967799 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -7,8 +7,6 @@ from math import inf from typing import List, Tuple import dateparser as dtparse -from dateparser.search import search_dates -from dateparser_data.settings import default_parsers from dateutil.relativedelta import relativedelta from khoj.search_filter.base_filter import BaseFilter diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py index 9883ea70..85295beb 100644 --- a/src/khoj/search_filter/file_filter.py +++ b/src/khoj/search_filter/file_filter.py @@ -1,11 +1,10 @@ -import fnmatch import logging import re from collections import defaultdict from typing import List from khoj.search_filter.base_filter import BaseFilter -from khoj.utils.helpers import LRU, timer +from khoj.utils.helpers import LRU logger = logging.getLogger(__name__) From 9e7025b33087a10464af675fbf1f90fe40fcfaf7 Mon Sep 17 00:00:00 2001 
From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 00:28:51 -0700 Subject: [PATCH 27/88] Set python interpret sandbox url via environment variable --- src/khoj/processor/tools/run_code.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 2fd2e8d2..681c5f94 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -2,6 +2,7 @@ import asyncio import datetime import json import logging +import os from typing import Any, Callable, List, Optional import aiohttp @@ -21,6 +22,9 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) +SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080") + + async def run_code( query: str, conversation_history: dict, @@ -29,7 +33,7 @@ async def run_code( send_status_func: Optional[Callable] = None, uploaded_image_url: str = None, agent: Agent = None, - sandbox_url: str = "http://localhost:8080", + sandbox_url: str = SANDBOX_URL, ): # Generate Code if send_status_func: @@ -104,7 +108,7 @@ async def generate_python_code( return codes -async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]: +async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]: """ Takes code to run as a string and calls the terrarium API to execute it. Returns the result of the code execution as a dictionary. From 61df1d5db87551e6a327949d283094ab77d8cb79 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 00:59:25 -0700 Subject: [PATCH 28/88] Pass previous iteration results to code interpreter chat actors This improves the code interpreter chat actors abilitiy to generate code with data collected during the previous iterations --- src/khoj/processor/conversation/prompts.py | 5 ++- src/khoj/processor/conversation/utils.py | 34 ++++++++++++++++++ src/khoj/processor/tools/run_code.py | 5 ++- src/khoj/routers/api_chat.py | 2 ++ src/khoj/routers/research.py | 42 +++++----------------- 5 files changed, 53 insertions(+), 35 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 23788cab..a82dc18f 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -753,7 +753,10 @@ For example: {{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}} Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else. 
-History: +Data from Previous Iterations: +{previous_iterations_history} + +Chat History: {chat_history} User: {query} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 339024ae..b8960e0b 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -79,6 +79,40 @@ class ThreadedGenerator: self.queue.put(StopIteration) +class InformationCollectionIteration: + def __init__( + self, + data_source: str, + query: str, + context: Dict[str, Dict] = None, + onlineContext: dict = None, + codeContext: dict = None, + summarizedResult: str = None, + ): + self.data_source = data_source + self.query = query + self.context = context + self.onlineContext = onlineContext + self.codeContext = codeContext + self.summarizedResult = summarizedResult + + +def construct_iteration_history( + previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str +) -> str: + previous_iterations_history = "" + for idx, iteration in enumerate(previous_iterations): + iteration_data = previous_iteration_prompt.format( + query=iteration.query, + data_source=iteration.data_source, + summary=iteration.summarizedResult, + index=idx + 1, + ) + + previous_iterations_history += iteration_data + return previous_iterations_history + + def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: chat_history = "" for chat in conversation_history.get("chat", [])[-n:]: diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 681c5f94..384b993c 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -28,6 +28,7 @@ SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080") async def run_code( query: str, conversation_history: dict, + previous_iterations_history: str, location_data: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, @@ -42,7 +43,7 @@ async def run_code( try: with timer("Chat actor: Generate programs to execute", logger): codes = await generate_python_code( - query, conversation_history, location_data, user, uploaded_image_url, agent + query, conversation_history, previous_iterations_history, location_data, user, uploaded_image_url, agent ) except Exception as e: raise ValueError(f"Failed to generate code for {query} with error: {e}") @@ -66,6 +67,7 @@ async def run_code( async def generate_python_code( q: str, conversation_history: dict, + previous_iterations_history: str, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None, @@ -85,6 +87,7 @@ async def generate_python_code( current_date=utc_date, query=q, chat_history=chat_history, + previous_iterations_history=previous_iterations_history, location=location, username=username, personality_context=personality_context, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 60d414de..2a9654ef 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -954,9 +954,11 @@ async def chat( ## Gather Code Results if ConversationCommand.Code in conversation_commands: try: + previous_iteration_history = "" async for result in run_code( defiltered_query, meta_log, + previous_iteration_history, location, user, partial(send_event, ChatEvent.STATUS), diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 9a5e3169..1ada9e7a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,11 @@ from fastapi import 
Request from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import remove_json_codeblock +from khoj.processor.conversation.utils import ( + InformationCollectionIteration, + construct_iteration_history, + remove_json_codeblock, +) from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions @@ -30,24 +34,6 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) -class InformationCollectionIteration: - def __init__( - self, - data_source: str, - query: str, - context: Dict[str, Dict] = None, - onlineContext: dict = None, - codeContext: dict = None, - summarizedResult: str = None, - ): - self.data_source = data_source - self.query = query - self.context = context - self.onlineContext = onlineContext - self.codeContext = codeContext - self.summarizedResult = summarizedResult - - async def apick_next_tool( query: str, conversation_history: dict, @@ -56,7 +42,7 @@ async def apick_next_tool( location: LocationData = None, user_name: str = None, agent: Agent = None, - previous_iterations: List[InformationCollectionIteration] = None, + previous_iterations_history: str = None, max_iterations: int = 5, ): """ @@ -75,17 +61,6 @@ async def apick_next_tool( chat_history = construct_chat_history(conversation_history) - previous_iterations_history = "" - for idx, iteration in enumerate(previous_iterations): - iteration_data = prompts.previous_iteration.format( - query=iteration.query, - data_source=iteration.data_source, - summary=iteration.summarizedResult, - index=idx + 1, - ) - - previous_iterations_history += iteration_data - if uploaded_image_url: query = f"[placeholder for user attached image]\n{query}" @@ -98,7 +73,6 @@ async def apick_next_tool( location_data = f"{location}" if location else "Unknown" username = prompts.user_name.format(name=user_name) if user_name else "" - # TODO Add current date/time to the query function_planning_prompt = prompts.plan_function_execution.format( query=query, tools=tool_options_str, @@ -166,6 +140,7 @@ async def execute_information_collection( code_results: Dict = dict() compiled_references: List[Any] = [] inferred_queries: List[Any] = [] + previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) result: str = "" @@ -177,7 +152,7 @@ async def execute_information_collection( location, user_name, agent, - previous_iterations, + previous_iterations_history, MAX_ITERATIONS, ) if this_iteration.data_source == ConversationCommand.Notes: @@ -268,6 +243,7 @@ async def execute_information_collection( async for result in run_code( this_iteration.query, conversation_history, + previous_iterations_history, location, user, send_status_func, From 5a699a52d28d33e3e10118be08260ea045a94930 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 02:08:18 -0700 Subject: [PATCH 29/88] Improve webpage summarization prompt to better extract links, excerpts This change allows the iterative director to dive deeper into its research as the data extracted contains relevant links from the webpage Previous summarization prompt didn't extract relevant links from the webpage which limited further explorations from webpages --- src/khoj/processor/conversation/prompts.py | 17 +++++++---------- 1 file changed, 7 
insertions(+), 10 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index a82dc18f..cc9cb2af 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -369,17 +369,14 @@ Assistant: ) system_prompt_extract_relevant_information = """ -As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. -The text provided is directly from within the web page. -The report you create should be multiple paragraphs, and it should represent the content of the website. -Tell the user exactly what the website says in response to their query, while adhering to these guidelines: +As a professional analyst, your job is to extract all pertinent information from webpages to help answer user's query. +You will be provided raw text directly from within the web page. +Adhere to these guidelines while extracting information from the provided webpages: -1. Answer the user's query as specifically as possible. Include many supporting details from the website. -2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity. -3. Rely strictly on the provided text, without including external information. -4. Format the report in multiple paragraphs with a clear structure. -5. Be as specific as possible in your answer to the user's query. -6. Reproduce as much of the provided text as possible, while maintaining readability. +1. Extract all relevant text and links from the webpage that can assist with further research or answer the user's query. +2. Craft a comprehensive but compact report with all the necessary data from the website to generate an informed response. +3. Rely strictly on the provided text to generate your summary, without including external information. +4. Provide specific, important snippets from the webpage in your report. """.strip() extract_relevant_information = PromptTemplate.from_template(
From 1e390325d2d668a44ac6caa273d82fc287475778 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 02:19:08 -0700 Subject: [PATCH 30/88] Let research chat director decide which webpage to read, if any Make the number of webpages to read automatically on search_online configurable via an argument. Set it to default to 1, so other callers of the function are unaffected. But the iterative chat director can still decide which, if any, webpages to read based on the online search it performs --- src/khoj/processor/tools/online_search.py | 6 +++--- src/khoj/routers/research.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 840d8a81..c2e051d6 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -46,8 +46,7 @@ OLOSTEP_QUERY_PARAMS = { "expandHtml": "False", } -# TODO: Should this be 0 to let advanced model decide which web pages to read?
-MAX_WEBPAGES_TO_READ = 1 +DEFAULT_MAX_WEBPAGES_TO_READ = 1 async def search_online( @@ -58,6 +57,7 @@ async def search_online( subscribed: bool = False, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], + max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, uploaded_image_url: str = None, agent: Agent = None, ): @@ -91,7 +91,7 @@ async def search_online( webpages = { (organic.get("link"), subquery, organic.get("content")) for subquery in response_dict - for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ] + for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read] if "answerBox" not in response_dict[subquery] } diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 1ada9e7a..de47419a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -199,6 +199,7 @@ async def execute_information_collection( subscribed, send_status_func, [], + max_webpages_to_read=0, uploaded_image_url=uploaded_image_url, agent=agent, ): From 284c8c331b14218e2e125034aef9b6b5525e31f4 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 02:22:24 -0700 Subject: [PATCH 31/88] Increase default max iterations for research chat director to 5 --- src/khoj/routers/research.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index de47419a..6f479a77 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -133,7 +133,7 @@ async def execute_information_collection( file_filters: List[str] = [], ): current_iteration = 0 - MAX_ITERATIONS = 2 + MAX_ITERATIONS = 5 previous_iterations: List[InformationCollectionIteration] = [] while current_iteration < MAX_ITERATIONS: online_results: Dict = dict() From 0eacc0b2b094cacfec7637a408587adde897d131 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 03:54:31 -0700 Subject: [PATCH 32/88] Use consistent name for user, planner to not miss current user query Previously Khoj would start answering the previous query. This may be because the prompt used 'User' for messages in the chat history but 'Q' for the current user prompt. --- src/khoj/processor/conversation/prompts.py | 6 +++--- src/khoj/routers/research.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index cc9cb2af..c72122e6 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -484,7 +484,7 @@ Khoj: plan_function_execution = PromptTemplate.from_template( """ -You are a smart, methodical researcher. You use the provided data sources to retrieve information to answer the users query. +You are Khoj, a smart, methodical researcher. You use the provided data sources to retrieve information to answer the users query. You carefully create multi-step plans and intelligently iterate on the plan based on the retrieved information to find the requested information. {personality_context} - Use the data sources provided below, one at a time, if you need to find more information. Their output will be shown to you in the next iteration.
@@ -523,8 +523,8 @@ Response format: Chat History: {chat_history} -Q: {query} -Response: +User: {query} +Khoj: """.strip() ) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 6f479a77..7bae6c7d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -59,7 +59,7 @@ async def apick_next_tool( if len(agent_tools) == 0 or tool.value in agent_tools: tool_options_str += f'- "{tool.value}": "{description}"\n' - chat_history = construct_chat_history(conversation_history) + chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") if uploaded_image_url: query = f"[placeholder for user attached image]\n{query}" From 6a8fd9bf3370dfe0e945294dc6b4db545ea36469 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 03:57:09 -0700 Subject: [PATCH 33/88] Reorder embeddings search arguments based on argument importance --- src/khoj/database/adapters/__init__.py | 4 ++-- src/khoj/routers/api.py | 2 +- src/khoj/search_type/text_search.py | 6 +++--- tests/test_text_search.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9687ec01..027490be 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1413,11 +1413,11 @@ class EntryAdapters: @staticmethod def search_with_embeddings( - user: KhojUser, + raw_query: str, embeddings: Tensor, + user: KhojUser, max_results: int = 10, file_type_filter: str = None, - raw_query: str = None, max_distance: float = math.inf, agent: Agent = None, ): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 46fdfa43..6a30e194 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -160,8 +160,8 @@ async def execute_search( search_futures += [ executor.submit( text_search.query, - user, user_query, + user, t, question_embedding=encoded_asymmetric_query, max_distance=max_distance, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 52e23f29..ae873c33 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -97,8 +97,8 @@ def load_embeddings( async def query( - user: KhojUser, raw_query: str, + user: KhojUser, type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, max_distance: float = None, @@ -125,12 +125,12 @@ async def query( top_k = 10 with timer("Search Time", logger, state.device): hits = EntryAdapters.search_with_embeddings( - user=user, + raw_query=raw_query, embeddings=question_embedding, max_results=top_k, file_type_filter=file_type, - raw_query=raw_query, max_distance=max_distance, + user=user, agent=agent, ).all() hits = await sync_to_async(list)(hits) # type: ignore[call-arg] diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 4529aa53..712f4aba 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig): query = "Load Khoj on Emacs?" 
# Act - hits = await text_search.query(default_user, query) + hits = await text_search.query(query, default_user) results = text_search.collate_results(hits) results = sorted(results, key=lambda x: float(x.score))[:1] From 564491e164a854a4982a038b8ae9a980efaa26a5 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 03:58:20 -0700 Subject: [PATCH 34/88] Extract date filters quoted with non-ascii quotes in query --- src/khoj/search_filter/date_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index 91967799..f533a455 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -21,7 +21,7 @@ class DateFilter(BaseFilter): # - dt>="yesterday" dt<"tomorrow" # - dt>="last week" # - dt:"2 years ago" - date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']" + date_regex = r"dt([:><=]{1,2})[\"'‘’](.*?)[\"'‘’]" def __init__(self, entry_key="compiled"): self.entry_key = entry_key From f462d3454776f08206c63d0363df77012a24eb37 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 04:14:21 -0700 Subject: [PATCH 35/88] Render images files output by code interpreter in message on web app --- src/interface/web/app/common/chatFunctions.ts | 43 +++++++++++++++++++ .../components/chatMessage/chatMessage.tsx | 25 +++++++++++ 2 files changed, 68 insertions(+) diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index b1bc60ed..e6c402a3 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -115,6 +115,33 @@ export function processMessageChunk( if (onlineContext) currentMessage.onlineContext = onlineContext; if (context) currentMessage.context = context; + // Replace file links with base64 data + currentMessage.rawResponse = replaceFileLinksWithBase64( + currentMessage.rawResponse, + codeContext, + ); + + // Add code context files to the message + if (codeContext) { + Object.entries(codeContext).forEach(([key, value]) => { + value.results.output_files?.forEach((file) => { + if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) { + // Don't add the image again if it's already in the message! 
+ if (!currentMessage.rawResponse.includes(`![${file.filename}](`)) { + currentMessage.rawResponse += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`; + } + } else if ( + file.filename.endsWith(".txt") || + file.filename.endsWith(".org") || + file.filename.endsWith(".md") + ) { + const decodedText = atob(file.b64_data); + currentMessage.rawResponse += `\n\n\`\`\`\n${decodedText}\n\`\`\``; + } + }); + }); + } + // Mark current message streaming as completed currentMessage.completed = true; } @@ -159,6 +186,22 @@ export function handleImageResponse(imageJson: any, liveStream: boolean): Respon return reference; } +export function replaceFileLinksWithBase64(message: string, codeContext: CodeContext) { + if (!codeContext) return message; + + Object.values(codeContext).forEach((contextData) => { + contextData.results.output_files?.forEach((file) => { + const regex = new RegExp(`!\\[.*?\\]\\(.*${file.filename}\\)`, "g"); + if (file.filename.match(/\.(png|jpg|jpeg|gif|webp)$/i)) { + const replacement = `![${file.filename}](data:image/${file.filename.split(".").pop()};base64,${file.b64_data})`; + message = message.replace(regex, replacement); + } + }); + }); + + return message; +} + export function modifyFileFilterForConversation( conversationId: string | null, filenames: string[], diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index b75c852e..75c2685b 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client"; import "katex/dist/katex.min.css"; import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel"; +import { replaceFileLinksWithBase64 } from "@/app/common/chatFunctions"; import { ThumbsUp, @@ -377,6 +378,30 @@ const ChatMessage = forwardRef((props, ref) => message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`; } + // Replace file links with base64 data + message = replaceFileLinksWithBase64(message, props.chatMessage.codeContext); + + // Add code context files to the message + if (props.chatMessage.codeContext) { + Object.entries(props.chatMessage.codeContext).forEach(([key, value]) => { + value.results.output_files?.forEach((file) => { + if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) { + // Don't add the image again if it's already in the message! 
+ if (!message.includes(`![${file.filename}](`)) { + message += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`; + } + } else if ( + file.filename.endsWith(".txt") || + file.filename.endsWith(".org") || + file.filename.endsWith(".md") + ) { + const decodedText = atob(file.b64_data); + message += `\n\n\`\`\`\n${decodedText}\n\`\`\``; + } + }); + }); + } + setTextRendered(message); // Render the markdown From 1b13d069f525d51741a4a60c80bccf7f61a51a06 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 05:19:27 -0700 Subject: [PATCH 36/88] Pass data collected from various sources to code tool in normal flow too --- src/khoj/routers/api_chat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 2a9654ef..48ca10f6 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -954,7 +954,9 @@ async def chat( ## Gather Code Results if ConversationCommand.Code in conversation_commands: try: - previous_iteration_history = "" + previous_iteration_history = ( + f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" + ) async for result in run_code( defiltered_query, meta_log, From 01a58b71a5113bb286bb7f81664b9e0b1997a801 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 10 Oct 2024 18:06:29 -0700 Subject: [PATCH 37/88] Skip image, code generation if in research mode --- src/khoj/routers/api_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 48ca10f6..271a5d52 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -952,7 +952,7 @@ async def chat( ) ## Gather Code Results - if ConversationCommand.Code in conversation_commands: + if ConversationCommand.Code in conversation_commands and pending_research: try: previous_iteration_history = ( f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}" @@ -993,7 +993,7 @@ async def chat( # Generate Output ## Generate Image Output - if ConversationCommand.Image in conversation_commands: + if ConversationCommand.Image in conversation_commands and pending_research: async for result in text_to_image( q, user, From 20d495c43acd0504febcb849070362361c505ac5 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 11 Oct 2024 00:28:56 -0700 Subject: [PATCH 38/88] Update the iterative chat director prompt to generalize across chat models These prompts work across o1 and standard openai model. Works with anthropic and google models as well --- src/khoj/processor/conversation/prompts.py | 67 +++++++++++----------- src/khoj/processor/conversation/utils.py | 8 +-- src/khoj/routers/research.py | 19 +++--- src/khoj/utils/helpers.py | 4 +- 4 files changed, 51 insertions(+), 47 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index c72122e6..6ac0268e 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -484,45 +484,49 @@ Khoj: plan_function_execution = PromptTemplate.from_template( """ -You are Khoj, a smart, methodical researcher. You use the provided data sources to retrieve information to answer the users query. -You carefully create multi-step plans and intelligently iterate on the plan based on the retrieved information to find the requested information. +You are Khoj, a smart, methodical researcher agent. 
Use the provided tool AIs to answer my query. +Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information. {personality_context} -- Use the data sources provided below, one at a time, if you need to find more information. Their output will be shown to you in the next iteration. -- You are allowed upto {max_iterations} iterations to use these data sources to answer the user's question -- If you have enough information to answer the question, then exit execution by returning an empty response. E.g., {{}} -- Ensure the query contains enough context to retrieve relevant information from the data sources. -- Break down the problem into smaller steps. Some examples are provided below assuming you have access to the notes and online data sources: - - If the user asks for the population of their hometown - 1. Try look up their hometown in their notes - 2. Only then try find the population of the city online. - - If the user asks for their computer's specs - 1. Try find the computer model in their notes - 2. Now look up their computer models spec online - - If the user asks what clothes to carry for their upcoming trip - 1. Find the itinerary of their upcoming trip in their notes - 2. Next find the weather forecast at the destination online - 3. Then find if they mention what clothes they own in their notes -Background Context: +# Instructions +- Ask detailed queries to the tool AIs provided below, one at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration. +- Break down your discovery and research process into independent, self-contained steps that can be executed sequentially. +- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer my question. +- When you have the required information return an empty JSON object. E.g., {{}} + +# Examples +Assuming you can search my notes and the internet. +- When I ask for the population of my hometown + 1. Try look up my hometown in my notes + 2. Only then try find the population of the city online. +- When I ask for my computer's specs + 1. Try find my computer model in my notes + 2. Now look up my computer model's spec online +- When I ask what clothes to carry for my upcoming trip + 1. Find the itinerary of my upcoming trip in my notes + 2. Next find the weather forecast at the destination online + 3. Then find if I mentioned what clothes I own in my notes + +# Background Context - Current Date: {day_of_week}, {current_date} -- User's Location: {location} -- {username} +- My Location: {location} +- My {username} -Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources: +# Available Tool AIs +Which of the tool AIs listed below would you use to answer my question? You **only** have access to the following tool AIs: {tools} -Provide the data source and associated query in a JSON object. Do not say anything else. - -Previous Iterations: +# Previous Iterations {previous_iterations} -Response format: -{{"data_source": "", "query": ""}} - -Chat History: +# Chat History: {chat_history} +Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else. 
+Response format: +{{"scratchpad": "", "tool": "", "query": ""}} + User: {query} Khoj: """.strip() @@ -530,11 +534,10 @@ Khoj: previous_iteration = PromptTemplate.from_template( """ -# Iteration {index}: -# --- -- data_source: {data_source} +## Iteration {index}: +- tool: {tool} - query: {query} -- summary: {summary} +- result: {result} """ ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b8960e0b..f875f0af 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -82,14 +82,14 @@ class ThreadedGenerator: class InformationCollectionIteration: def __init__( self, - data_source: str, + tool: str, query: str, context: Dict[str, Dict] = None, onlineContext: dict = None, codeContext: dict = None, summarizedResult: str = None, ): - self.data_source = data_source + self.tool = tool self.query = query self.context = context self.onlineContext = onlineContext @@ -103,9 +103,9 @@ def construct_iteration_history( previous_iterations_history = "" for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( + tool=iteration.tool, query=iteration.query, - data_source=iteration.data_source, - summary=iteration.summarizedResult, + result=iteration.summarizedResult, index=idx + 1, ) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 7bae6c7d..60c22c80 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -100,20 +100,21 @@ async def apick_next_tool( response = response.strip() response = remove_json_codeblock(response) response = json.loads(response) - suggested_data_source = response.get("data_source", None) - suggested_query = response.get("query", None) + selected_tool = response.get("tool", None) + generated_query = response.get("query", None) + scratchpad = response.get("scratchpad", None) logger.info(f"Response for determining relevant tools: {response}") return InformationCollectionIteration( - data_source=suggested_data_source, - query=suggested_query, + tool=selected_tool, + query=generated_query, ) except Exception as e: logger.error(f"Invalid response for determining relevant tools: {response}. 
{e}", exc_info=True) return InformationCollectionIteration( - data_source=None, + tool=None, query=None, ) @@ -155,7 +156,7 @@ async def execute_information_collection( previous_iterations_history, MAX_ITERATIONS, ) - if this_iteration.data_source == ConversationCommand.Notes: + if this_iteration.tool == ConversationCommand.Notes: ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None async for result in extract_references_and_questions( @@ -190,7 +191,7 @@ async def execute_information_collection( # TODO Get correct type for compiled across research notes extraction logger.error(f"Error extracting references: {e}", exc_info=True) - elif this_iteration.data_source == ConversationCommand.Online: + elif this_iteration.tool == ConversationCommand.Online: async for result in search_online( this_iteration.query, conversation_history, @@ -209,7 +210,7 @@ async def execute_information_collection( online_results: Dict[str, Dict] = result # type: ignore this_iteration.onlineContext = online_results - elif this_iteration.data_source == ConversationCommand.Webpage: + elif this_iteration.tool == ConversationCommand.Webpage: try: async for result in read_webpages( this_iteration.query, @@ -239,7 +240,7 @@ async def execute_information_collection( except Exception as e: logger.error(f"Error reading webpages: {e}", exc_info=True) - elif this_iteration.data_source == ConversationCommand.Code: + elif this_iteration.tool == ConversationCommand.Code: try: async for result in run_code( this_iteration.query, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d3978fa4..ac3a53f4 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -350,8 +350,8 @@ tool_descriptions_for_llm = { function_calling_description_for_llm = { ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", - ConversationCommand.Online: "To search for the latest, up-to-date information from the internet.", - ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.", + ConversationCommand.Online: "To search the internet for the latest, up-to-date information.", + ConversationCommand.Webpage: "To read a webpage url for detailed information from the internet.", ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.", } From 9daaae0fdbcd1417868870a232dd99aeec716921 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 11 Oct 2024 01:11:34 -0700 Subject: [PATCH 39/88] Render inline any image files output by code in message Update regex to also include any links to code generated images that aren't explicitly meant to be displayed inline. 
This allows folks to download the image (unlike the fake link that doesn't work created by model) --- src/interface/web/app/common/chatFunctions.ts | 6 +++--- .../web/app/components/chatMessage/chatMessage.tsx | 4 ++-- src/khoj/routers/helpers.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index e6c402a3..6035bf78 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -116,7 +116,7 @@ export function processMessageChunk( if (context) currentMessage.context = context; // Replace file links with base64 data - currentMessage.rawResponse = replaceFileLinksWithBase64( + currentMessage.rawResponse = renderCodeGenImageInline( currentMessage.rawResponse, codeContext, ); @@ -186,12 +186,12 @@ export function handleImageResponse(imageJson: any, liveStream: boolean): Respon return reference; } -export function replaceFileLinksWithBase64(message: string, codeContext: CodeContext) { +export function renderCodeGenImageInline(message: string, codeContext: CodeContext) { if (!codeContext) return message; Object.values(codeContext).forEach((contextData) => { contextData.results.output_files?.forEach((file) => { - const regex = new RegExp(`!\\[.*?\\]\\(.*${file.filename}\\)`, "g"); + const regex = new RegExp(`!?\\[.*?\\]\\(.*${file.filename}\\)`, "g"); if (file.filename.match(/\.(png|jpg|jpeg|gif|webp)$/i)) { const replacement = `![${file.filename}](data:image/${file.filename.split(".").pop()};base64,${file.b64_data})`; message = message.replace(regex, replacement); diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 75c2685b..31aa4b48 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -10,7 +10,7 @@ import { createRoot } from "react-dom/client"; import "katex/dist/katex.min.css"; import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel"; -import { replaceFileLinksWithBase64 } from "@/app/common/chatFunctions"; +import { renderCodeGenImageInline } from "@/app/common/chatFunctions"; import { ThumbsUp, @@ -379,7 +379,7 @@ const ChatMessage = forwardRef((props, ref) => } // Replace file links with base64 data - message = replaceFileLinksWithBase64(message, props.chatMessage.codeContext); + message = renderCodeGenImageInline(message, props.chatMessage.codeContext); // Add code context files to the message if (props.chatMessage.codeContext) { diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3b97e694..4814d390 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -452,6 +452,7 @@ async def infer_webpage_urls( # Validate that the response is a non-empty, JSON-serializable list of URLs try: response = response.strip() + response = remove_json_codeblock(response) urls = json.loads(response) valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)} if is_none_or_empty(valid_unique_urls): From 9314f0a398237229fca9881d9e7137bd991a0e1a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 13 Oct 2024 02:59:10 -0700 Subject: [PATCH 40/88] Fix default chat configs to use user model if no server chat model set Post merge cleanup in advanced reasoning to fallback to user chat model if no server chat model defined for advanced and default --- 
src/khoj/processor/conversation/helpers.py | 5 +++-- src/khoj/processor/tools/run_code.py | 2 +- src/khoj/routers/research.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/helpers.py b/src/khoj/processor/conversation/helpers.py index 06a8557c..4b7e472c 100644 --- a/src/khoj/processor/conversation/helpers.py +++ b/src/khoj/processor/conversation/helpers.py @@ -18,11 +18,11 @@ async def send_message_to_model_wrapper( system_message: str = "", response_type: str = "text", chat_model_option: ChatModelOptions = None, - subscribed: bool = False, + user: KhojUser = None, uploaded_image_url: str = None, ): conversation_config: ChatModelOptions = ( - chat_model_option or await ConversationAdapters.aget_default_conversation_config() + chat_model_option or await ConversationAdapters.aget_default_conversation_config(user) ) vision_available = conversation_config.vision_enabled @@ -32,6 +32,7 @@ async def send_message_to_model_wrapper( conversation_config = vision_enabled_config vision_available = True + subscribed = await ais_user_subscribed(user) chat_model = conversation_config.chat_model max_tokens = ( conversation_config.subscribed_max_prompt_size diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 384b993c..9da04237 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -97,7 +97,7 @@ async def generate_python_code( code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", - subscribed=subscribed, + user=user, ) # Validate that the response is a non-empty, JSON-serializable list diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 60c22c80..ed43c864 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) async def apick_next_tool( query: str, conversation_history: dict, - subscribed: bool, + user: KhojUser = None, uploaded_image_url: str = None, location: LocationData = None, user_name: str = None, @@ -86,13 +86,13 @@ async def apick_next_tool( max_iterations=max_iterations, ) - chat_model_option = await ConversationAdapters.aget_advanced_conversation_config() + chat_model_option = await ConversationAdapters.aget_advanced_conversation_config(user) with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( function_planning_prompt, response_type="json_object", - subscribed=subscribed, + user=user, chat_model_option=chat_model_option, ) @@ -148,7 +148,7 @@ async def execute_information_collection( this_iteration = await apick_next_tool( query, conversation_history, - subscribed, + user, uploaded_image_url, location, user_name, From 9356e66b943155b0ccf5913dc6c7cef4963a7709 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 13 Oct 2024 03:02:29 -0700 Subject: [PATCH 41/88] Fix default chat model to use user model if no server chat model set - Advanced chat model should also fallback to user chat model if set - Get conversation config should falback to user chat model if set These assume no server chat model settings is configured --- src/khoj/database/adapters/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 51a211b6..572db97d 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py 
@@ -937,21 +937,21 @@ class ConversationAdapters: def get_conversation_config(user: KhojUser): subscribed = is_user_subscribed(user) if not subscribed: - return ConversationAdapters.get_default_conversation_config() + return ConversationAdapters.get_default_conversation_config(user) config = UserConversationConfig.objects.filter(user=user).first() if config: return config.setting - return ConversationAdapters.get_advanced_conversation_config() + return ConversationAdapters.get_advanced_conversation_config(user) @staticmethod async def aget_conversation_config(user: KhojUser): subscribed = await ais_user_subscribed(user) if not subscribed: - return await ConversationAdapters.aget_default_conversation_config() + return await ConversationAdapters.aget_default_conversation_config(user) config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() if config: return config.setting - return ConversationAdapters.aget_advanced_conversation_config() + return ConversationAdapters.aget_advanced_conversation_config(user) @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: @@ -1012,22 +1012,22 @@ class ConversationAdapters: return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() @staticmethod - def get_advanced_conversation_config(): + def get_advanced_conversation_config(user: KhojUser): server_chat_settings = ServerChatSettings.objects.first() if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return ConversationAdapters.get_default_conversation_config() + return ConversationAdapters.get_default_conversation_config(user) @staticmethod - async def aget_advanced_conversation_config(): + async def aget_advanced_conversation_config(user: KhojUser = None): server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related("chat_advanced", "chat_advanced__openai_config") .afirst() ) - if server_chat_settings is not None or server_chat_settings.chat_advanced is not None: + if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return await ConversationAdapters.aget_default_conversation_config() + return await ConversationAdapters.aget_default_conversation_config(user) @staticmethod def create_conversation_from_public_conversation( From 263eee43513dcc882c0cc58c3c17512cb52f2c77 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 13 Oct 2024 03:02:29 -0700 Subject: [PATCH 42/88] Fix default chat model to use user model if no server chat model set - Advanced chat model should also fallback to user chat model if set - Get conversation config should falback to user chat model if set These assume no server chat model settings is configured --- src/khoj/database/adapters/__init__.py | 18 +++++++++--------- src/khoj/processor/tools/online_search.py | 1 - src/khoj/routers/api_chat.py | 1 - 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 2309bcd4..182ce701 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -939,21 +939,21 @@ class ConversationAdapters: def get_conversation_config(user: KhojUser): subscribed = is_user_subscribed(user) if not subscribed: - return ConversationAdapters.get_default_conversation_config() + return 
ConversationAdapters.get_default_conversation_config(user) config = UserConversationConfig.objects.filter(user=user).first() if config: return config.setting - return ConversationAdapters.get_advanced_conversation_config() + return ConversationAdapters.get_advanced_conversation_config(user) @staticmethod async def aget_conversation_config(user: KhojUser): subscribed = await ais_user_subscribed(user) if not subscribed: - return await ConversationAdapters.aget_default_conversation_config() + return await ConversationAdapters.aget_default_conversation_config(user) config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() if config: return config.setting - return ConversationAdapters.aget_advanced_conversation_config() + return ConversationAdapters.aget_advanced_conversation_config(user) @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: @@ -1014,22 +1014,22 @@ class ConversationAdapters: return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() @staticmethod - def get_advanced_conversation_config(): + def get_advanced_conversation_config(user: KhojUser): server_chat_settings = ServerChatSettings.objects.first() if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return ConversationAdapters.get_default_conversation_config() + return ConversationAdapters.get_default_conversation_config(user) @staticmethod - async def aget_advanced_conversation_config(): + async def aget_advanced_conversation_config(user: KhojUser = None): server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related("chat_advanced", "chat_advanced__openai_config") .afirst() ) - if server_chat_settings is not None or server_chat_settings.chat_advanced is not None: + if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return await ConversationAdapters.aget_default_conversation_config() + return await ConversationAdapters.aget_default_conversation_config(user) @staticmethod def create_conversation_from_public_conversation( diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 16539b5c..8cb4cf64 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -141,7 +141,6 @@ async def read_webpages( conversation_history: dict, location: LocationData, user: KhojUser, - subscribed: bool = False, send_status_func: Optional[Callable] = None, uploaded_image_url: str = None, agent: Agent = None, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 2674836d..071ff8f1 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -910,7 +910,6 @@ async def chat( meta_log, location, user, - subscribed, partial(send_event, ChatEvent.STATUS), uploaded_image_url=uploaded_image_url, agent=agent, From d6206aa80c0b71381049a127ad2cbecdc058385a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 14 Oct 2024 17:15:26 -0700 Subject: [PATCH 43/88] Remove deprecated GET chat API endpoint --- src/khoj/routers/api_chat.py | 479 ----------------------------------- 1 file changed, 479 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 071ff8f1..94a069da 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1047,482 +1047,3 @@ 
async def chat( response_iterator = event_generator(q, image=image) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) - - -# Deprecated API. Remove by end of September 2024 -@api_chat.get("") -@requires(["authenticated"]) -async def get_chat( - request: Request, - common: CommonQueryParams, - q: str, - n: int = 7, - d: float = None, - stream: Optional[bool] = False, - title: Optional[str] = None, - conversation_id: Optional[str] = None, - city: Optional[str] = None, - region: Optional[str] = None, - country: Optional[str] = None, - timezone: Optional[str] = None, - image: Optional[str] = None, - rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute") - ), - rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") - ), -): - # Issue a deprecation warning - warnings.warn( - "The 'get_chat' API endpoint is deprecated. It will be removed by the end of September 2024.", - DeprecationWarning, - stacklevel=2, - ) - - async def event_generator(q: str, image: str): - start_time = time.perf_counter() - ttft = None - chat_metadata: dict = {} - connection_alive = True - user: KhojUser = request.user.object - subscribed: bool = has_required_scope(request, ["premium"]) - event_delimiter = "␃🔚␗" - q = unquote(q) - nonlocal conversation_id - - uploaded_image_url = None - if image: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - try: - uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) - except: - uploaded_image_url = None - - async def send_event(event_type: ChatEvent, data: str | dict): - nonlocal connection_alive, ttft - if not connection_alive or await request.is_disconnected(): - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client") - return - try: - if event_type == ChatEvent.END_LLM_RESPONSE: - collect_telemetry() - if event_type == ChatEvent.START_LLM_RESPONSE: - ttft = time.perf_counter() - start_time - if event_type == ChatEvent.MESSAGE: - yield data - elif event_type == ChatEvent.REFERENCES or stream: - yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) - except asyncio.CancelledError as e: - connection_alive = False - logger.warn(f"User {user} disconnected from {common.client} client: {e}") - return - except Exception as e: - connection_alive = False - logger.error(f"Failed to stream chat API response to {user} on {common.client}: {e}", exc_info=True) - return - finally: - yield event_delimiter - - async def send_llm_response(response: str): - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - async for result in send_event(ChatEvent.MESSAGE, response): - yield result - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - - def collect_telemetry(): - # Gather chat response telemetry - nonlocal chat_metadata - latency = time.perf_counter() - start_time - cmd_set = set([cmd.value for cmd in conversation_commands]) - chat_metadata = chat_metadata or {} - chat_metadata["conversation_command"] = cmd_set - chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None - chat_metadata["latency"] = f"{latency:.3f}" - chat_metadata["ttft_latency"] = 
f"{ttft:.3f}" - - logger.info(f"Chat response time to first token: {ttft:.3f} seconds") - logger.info(f"Chat response total time: {latency:.3f} seconds") - update_telemetry_state( - request=request, - telemetry_type="api", - api="chat", - client=request.user.client_app, - user_agent=request.headers.get("user-agent"), - host=request.headers.get("host"), - metadata=chat_metadata, - ) - - conversation_commands = [get_conversation_command(query=q, any_references=True)] - - conversation = await ConversationAdapters.aget_conversation_by_user( - user, client_application=request.user.client_app, conversation_id=conversation_id, title=title - ) - if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found"): - yield result - return - conversation_id = conversation.id - agent = conversation.agent if conversation.agent else None - - await is_ready_to_chat(user) - - user_name = await aget_user_name(user) - location = None - if city or region or country: - location = LocationData(city=city, region=region, country=country) - - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - meta_log = conversation.conversation_log - is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask] - - if conversation_commands == [ConversationCommand.Default] or is_automated_task: - conversation_commands = await aget_relevant_information_sources( - q, meta_log, is_automated_task, user=user, uploaded_image_url=uploaded_image_url - ) - conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) - async for result in send_event( - ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}" - ): - yield result - - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url) - async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): - yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - - for cmd in conversation_commands: - await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) - q = q.replace(f"/{cmd.value}", "").strip() - - used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] - file_filters = conversation.file_filters if conversation else [] - # Skip trying to summarize if - if ( - # summarization intent was inferred - ConversationCommand.Summarize in conversation_commands - # and not triggered via slash command - and not used_slash_summarize - # but we can't actually summarize - and len(file_filters) != 1 - ): - conversation_commands.remove(ConversationCommand.Summarize) - elif ConversationCommand.Summarize in conversation_commands: - response_log = "" - if len(file_filters) == 0: - response_log = "No files selected for summarization. Please add files using the section on the left." - async for result in send_llm_response(response_log): - yield result - elif len(file_filters) > 1: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result - else: - try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." 
- async for result in send_llm_response(response_log): - yield result - return - contextual_data = " ".join([file.raw_text for file in file_object]) - if not q: - q = "Create a general summary of the file" - async for result in send_event( - ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}" - ): - yield result - - response = await extract_relevant_summary( - q, - contextual_data, - conversation_history=meta_log, - user=user, - uploaded_image_url=uploaded_image_url, - ) - response_log = str(response) - async for result in send_llm_response(response_log): - yield result - except Exception as e: - response_log = "Error summarizing file." - logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_llm_response(response_log): - yield result - await sync_to_async(save_to_conversation_log)( - q, - response_log, - user, - meta_log, - user_message_time, - intent_type="summarize", - client_application=request.user.client_app, - conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, - ) - return - - custom_filters = [] - if conversation_commands == [ConversationCommand.Help]: - if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() - model_type = conversation_config.model_type - formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - async for result in send_llm_response(formatted_help): - yield result - return - # Adding specification to search online specifically on khoj.dev pages. - custom_filters.append("site:khoj.dev") - conversation_commands.append(ConversationCommand.Online) - - if ConversationCommand.Automation in conversation_commands: - try: - automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log - ) - except Exception as e: - logger.error(f"Error scheduling task {q} for {user.email}: {e}") - error_message = f"Unable to create automation. Ensure the automation doesn't already exist." 
- async for result in send_llm_response(error_message): - yield result - return - - llm_response = construct_automation_created_message(automation, crontime, query_to_run, subject) - await sync_to_async(save_to_conversation_log)( - q, - llm_response, - user, - meta_log, - user_message_time, - intent_type="automation", - client_application=request.user.client_app, - conversation_id=conversation_id, - inferred_queries=[query_to_run], - automation_id=automation.id, - uploaded_image_url=uploaded_image_url, - ) - async for result in send_llm_response(llm_response): - yield result - return - - # Gather Context - ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - - if not is_none_or_empty(compiled_references): - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) - # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"): - yield result - - online_results: Dict = dict() - - if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}"): - yield result - return - - if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references): - conversation_commands.remove(ConversationCommand.Notes) - - ## Gather Online References - if ConversationCommand.Online in conversation_commands: - try: - async for result in search_online( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - custom_filters, - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - online_results = result - except ValueError as e: - error_message = f"Error searching online: {e}. Attempting to respond without online results" - logger.warning(error_message) - async for result in send_llm_response(error_message): - yield result - return - - ## Gather Webpage References - if ConversationCommand.Webpage in conversation_commands: - try: - async for result in read_webpages( - defiltered_query, - meta_log, - location, - user, - subscribed, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages = result - webpages = [] - for query in direct_web_pages: - if online_results.get(query): - online_results[query]["webpages"] = direct_web_pages[query]["webpages"] - else: - online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} - - for webpage in direct_web_pages[query]["webpages"]: - webpages.append(webpage["link"]) - async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): - yield result - except ValueError as e: - logger.warning( - f"Error directly reading webpages: {e}. 
Attempting to respond without online results", - exc_info=True, - ) - - ## Send Gathered References - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": online_results, - }, - ): - yield result - - # Generate Output - ## Generate Image Output - if ConversationCommand.Image in conversation_commands: - async for result in text_to_image( - q, - user, - meta_log, - location_data=location, - references=compiled_references, - online_results=online_results, - send_status_func=partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - image, status_code, improved_image_prompt, intent_type = result - - if image is None or status_code != 200: - content_obj = { - "content-type": "application/json", - "intentType": intent_type, - "detail": improved_image_prompt, - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - await sync_to_async(save_to_conversation_log)( - q, - image, - user, - meta_log, - user_message_time, - intent_type=intent_type, - inferred_queries=[improved_image_prompt], - client_application=request.user.client_app, - conversation_id=conversation_id, - compiled_references=compiled_references, - online_results=online_results, - uploaded_image_url=uploaded_image_url, - ) - content_obj = { - "intentType": intent_type, - "inferredQueries": [improved_image_prompt], - "image": image, - } - async for result in send_llm_response(json.dumps(content_obj)): - yield result - return - - ## Generate Text Output - async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"): - yield result - llm_response, chat_metadata = await agenerate_chat_response( - defiltered_query, - meta_log, - conversation, - compiled_references, - online_results, - inferred_queries, - conversation_commands, - user, - request.user.client_app, - conversation_id, - location, - user_name, - uploaded_image_url, - ) - - # Send Response - async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): - yield result - - continue_stream = True - iterator = AsyncIteratorWrapper(llm_response) - async for item in iterator: - if item is None: - async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): - yield result - logger.debug("Finished streaming response") - return - if not connection_alive or not continue_stream: - continue - try: - async for result in send_event(ChatEvent.MESSAGE, f"{item}"): - yield result - except Exception as e: - continue_stream = False - logger.info(f"User {user} disconnected. Emitting rest of responses to clear thread: {e}") - - ## Stream Text Response - if stream: - return StreamingResponse(event_generator(q, image=image), media_type="text/plain") - ## Non-Streaming Text Response - else: - response_iterator = event_generator(q, image=image) - response_data = await read_chat_stream(response_iterator) - return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) From 07ab7ebf07045e63f5c9d6dcf388a2dceae7224e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 14 Oct 2024 17:39:44 -0700 Subject: [PATCH 44/88] Try respond even if document search via inference endpoint fails The huggingface endpoint can be flaky. Khoj shouldn't refuse to respond to user if document search fails. 
It should transparently mention that document lookup failed, but try to respond as best it can without the document references. This change provides graceful failover when inference endpoint requests fail, either when encoding the query or when reranking retrieved docs.
--- src/khoj/processor/embeddings.py | 1 + src/khoj/routers/api_chat.py | 47 +++++++++++++++++------------ src/khoj/search_type/text_search.py | 9 ++++-- 3 files changed, 35 insertions(+), 22 deletions(-)
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 71af5b7d..a19d85fa 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -114,6 +114,7 @@ class CrossEncoderModel: payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}} headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} response = requests.post(target_url, json=payload, headers=headers) + response.raise_for_status() return response.json()["scores"] cross_inp = [[query, hit.additional[key]] for hit in hits]
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 94a069da..03bf5f50 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,7 +3,6 @@ import base64 import json import logging import time -import warnings from datetime import datetime from functools import partial from typing import Dict, Optional @@ -840,25 +839,33 @@ async def chat( # Gather Context ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None - async for result in extract_references_and_questions( - request, - meta_log, - q, - (n or 7), - d, - conversation_id, - conversation_commands, - location, - partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, - agent=agent, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] + try: + async for result in extract_references_and_questions( + request, + meta_log, + q, + (n or 7), + d, + conversation_id, + conversation_commands, + location, + partial(send_event, ChatEvent.STATUS), + uploaded_image_url=uploaded_image_url, + agent=agent, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + compiled_references.extend(result[0]) + inferred_queries.extend(result[1]) + defiltered_query = result[2] + except Exception as e: + error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references." + logger.warning(error_message) + async for result in send_event( + ChatEvent.STATUS, "Document search failed.
I'll try respond without document references" + ): + yield result if not is_none_or_empty(compiled_references): headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 52e23f29..b67132e4 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -3,6 +3,7 @@ import math from pathlib import Path from typing import List, Optional, Tuple, Type, Union +import requests import torch from asgiref.sync import sync_to_async from sentence_transformers import util @@ -231,8 +232,12 @@ def setup( def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" - with timer("Cross-Encoder Predict Time", logger, state.device): - cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + try: + with timer("Cross-Encoder Predict Time", logger, state.device): + cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) + except requests.exceptions.HTTPError as e: + logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True) + cross_scores = [0.0] * len(hits) # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)): From 3c93f07b3ff4b6b373fe68470b9bb87c65fe4205 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 14 Oct 2024 17:44:46 -0700 Subject: [PATCH 45/88] Try respond even if web search, webpage read fails during chat Khoj shouldn't refuse to respond to user if web lookups fail. It should transparently mention that online search etc. failed. But try respond as best as it can without those references This change ensures a response to the users query is attempted even when web info retrieval fails. --- src/khoj/routers/api_chat.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 03bf5f50..29f2e737 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -902,12 +902,13 @@ async def chat( yield result[ChatEvent.STATUS] else: online_results = result - except ValueError as e: + except Exception as e: error_message = f"Error searching online: {e}. Attempting to respond without online results" logger.warning(error_message) - async for result in send_llm_response(error_message): + async for result in send_event( + ChatEvent.STATUS, "Online search failed. I'll try respond without online references" + ): yield result - return ## Gather Webpage References if ConversationCommand.Webpage in conversation_commands: @@ -936,11 +937,15 @@ async def chat( webpages.append(webpage["link"]) async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"): yield result - except ValueError as e: + except Exception as e: logger.warning( - f"Error directly reading webpages: {e}. Attempting to respond without online results", + f"Error reading webpages: {e}. Attempting to respond without webpage results", exc_info=True, ) + async for result in send_event( + ChatEvent.STATUS, "Webpage read failed. 
I'll try respond without webpage references" + ): + yield result ## Send Gathered References async for result in send_event( From 81fb65fa0a708d9531a422cd97860492a98a5657 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 14 Oct 2024 18:14:40 -0700 Subject: [PATCH 46/88] Return data sources to use if exception in data source chat actor Previously no value was returned if an exception got triggered when collecting information sources to search. --- src/khoj/routers/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 245fdf09..a80864ba 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -353,13 +353,13 @@ async def aget_relevant_information_sources( final_response = [ConversationCommand.Default] else: final_response = [ConversationCommand.General] - return final_response - except Exception as e: + except Exception: logger.error(f"Invalid response for determining relevant tools: {response}") if len(agent_tools) == 0: final_response = [ConversationCommand.Default] else: final_response = agent_tools + return final_response async def aget_relevant_output_modes( From 336c6c36894c08ecbed6ca6e1f6d26d1fe3dd069 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 15 Oct 2024 01:08:48 -0700 Subject: [PATCH 47/88] Show tool to use decision for next iteration in train of thought --- src/khoj/routers/research.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index ed43c864..12a733d7 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -44,6 +44,7 @@ async def apick_next_tool( agent: Agent = None, previous_iterations_history: str = None, max_iterations: int = 5, + send_status_func: Optional[Callable] = None, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. @@ -103,17 +104,22 @@ async def apick_next_tool( selected_tool = response.get("tool", None) generated_query = response.get("query", None) scratchpad = response.get("scratchpad", None) - logger.info(f"Response for determining relevant tools: {response}") + if send_status_func: + determined_tool_message = "**Determined Tool**: " + determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond." + determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else "" + async for event in send_status_func(f"{scratchpad}"): + yield {ChatEvent.STATUS: event} - return InformationCollectionIteration( + yield InformationCollectionIteration( tool=selected_tool, query=generated_query, ) except Exception as e: logger.error(f"Invalid response for determining relevant tools: {response}. 
{e}", exc_info=True) - return InformationCollectionIteration( + yield InformationCollectionIteration( tool=None, query=None, ) @@ -143,9 +149,7 @@ async def execute_information_collection( inferred_queries: List[Any] = [] previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) - result: str = "" - - this_iteration = await apick_next_tool( + async for result in apick_next_tool( query, conversation_history, user, @@ -155,7 +159,13 @@ async def execute_information_collection( agent, previous_iterations_history, MAX_ITERATIONS, - ) + send_status_func, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + this_iteration = result + if this_iteration.tool == ConversationCommand.Notes: ## Extract Document References compiled_references, inferred_queries, defiltered_query = [], [], None From c6c48cfc18cb712fc9aad285bdbfa8d8bd5affe0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 17 Oct 2024 13:34:56 -0700 Subject: [PATCH 48/88] Fix arg to generate_summary_from_file and type of this_iteration --- src/khoj/processor/conversation/utils.py | 2 +- src/khoj/routers/helpers.py | 2 +- src/khoj/routers/research.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index f875f0af..6546595e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -84,7 +84,7 @@ class InformationCollectionIteration: self, tool: str, query: str, - context: Dict[str, Dict] = None, + context: list = None, onlineContext: dict = None, codeContext: dict = None, summarizedResult: str = None, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index cd051138..b05a0642 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -647,8 +647,8 @@ async def generate_summary_from_files( q, contextual_data, conversation_history=meta_log, - subscribed=subscribed, uploaded_image_url=uploaded_image_url, + user=user, agent=agent, ) response_log = str(response) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index b8f4301a..80a673c9 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -147,6 +147,7 @@ async def execute_information_collection( code_results: Dict = dict() compiled_references: List[Any] = [] inferred_queries: List[Any] = [] + this_iteration = InformationCollectionIteration(tool=None, query=query) previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) async for result in apick_next_tool( @@ -163,7 +164,7 @@ async def execute_information_collection( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: + elif isinstance(result, InformationCollectionIteration): this_iteration = result if this_iteration.tool == ConversationCommand.Notes: From 12b32a3d04b1f2e3bc2981dba223205ff1cbbb80 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 23 Oct 2024 19:30:47 -0700 Subject: [PATCH 49/88] Resolve merge conflicts --- src/khoj/processor/conversation/utils.py | 3 +++ src/khoj/processor/tools/online_search.py | 13 +++++++++---- src/khoj/routers/helpers.py | 23 +---------------------- src/khoj/routers/research.py | 4 +--- 4 files changed, 14 insertions(+), 29 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 76de46ec..e8ef32f8 
100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -122,6 +122,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): chat_history += f"User: {chat['intent']['query']}\n" chat_history += f"{agent_name}: [generated image redacted for space]\n" + elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" return chat_history diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 7a098721..739d4c70 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -4,7 +4,7 @@ import logging import os import urllib.parse from collections import defaultdict -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import aiohttp from bs4 import BeautifulSoup @@ -94,11 +94,16 @@ async def search_online( # Gather distinct web pages from organic results for subqueries without an instant answer. # Content of web pages is directly available when Jina is used for search. - webpages = set() + webpages: Dict[str, Dict] = {} for subquery in response_dict: + if "answerBox" in response_dict[subquery]: + continue for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]: - if "answerBox" not in response_dict[subquery]: - webpages.add(organic.get("link"), {"queries": {subquery}, "content": organic.get("content")}) + link = organic.get("link") + if link in webpages: + webpages[link]["queries"].add(subquery) + else: + webpages[link] = {"queries": {subquery}, "content": organic.get("content")} # Read, extract relevant info from the retrieved web pages if webpages: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a28d8ec2..2585f77d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -82,7 +82,6 @@ from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, gemini_send_message_to_model, ) -from khoj.processor.conversation.helpers import send_message_to_model_wrapper from khoj.processor.conversation.offline.chat_model import ( converse_offline, send_message_to_model_offline, @@ -214,21 +213,6 @@ def get_next_url(request: Request) -> str: return urljoin(str(request.base_url).rstrip("/"), next_path) -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['message']}\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" - return chat_history - - def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand: if query.startswith("/notes"): return 
ConversationCommand.Notes @@ -1129,9 +1113,9 @@ def generate_chat_response( if conversation_config.model_type == "offline": loaded_model = state.offline_chat_processor_config.loaded_model chat_response = converse_offline( + user_query=query_to_run, references=compiled_references, online_results=online_results, - user_query=query_to_run, loaded_model=loaded_model, conversation_log=meta_log, completion_func=partial_completion, @@ -1151,7 +1135,6 @@ def generate_chat_response( chat_response = converse( compiled_references, query_to_run, - q, query_images=query_images, online_results=online_results, code_results=code_results, @@ -1195,10 +1178,6 @@ def generate_chat_response( online_results, code_results, meta_log, - q, - query_images=query_images, - online_results=online_results, - conversation_log=meta_log, model=conversation_config.chat_model, api_key=api_key, completion_func=partial_completion, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 84577ca7..d1578b5b 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -87,14 +87,12 @@ async def apick_next_tool( max_iterations=max_iterations, ) - chat_model_option = await ConversationAdapters.aget_advanced_conversation_config(user) - with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( function_planning_prompt, response_type="json_object", user=user, - chat_model_option=chat_model_option, + query_images=query_images, ) try: From 5acf40c440fc7aec95ba244946dd468e450f7a38 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 23 Oct 2024 20:06:04 -0700 Subject: [PATCH 50/88] Clean up summarization code paths Use assumption of summarization response being a str --- src/khoj/routers/api_chat.py | 7 ++-- src/khoj/routers/helpers.py | 15 +++------ src/khoj/routers/research.py | 65 +++++++++++++----------------------- 3 files changed, 32 insertions(+), 55 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 69196e83..32513b55 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -704,7 +704,7 @@ async def chat( location=location, file_filters=conversation.file_filters if conversation else [], ): - if type(research_result) == InformationCollectionIteration: + if isinstance(research_result, InformationCollectionIteration): if research_result.summarizedResult: pending_research = False if research_result.onlineContext: @@ -778,12 +778,13 @@ async def chat( query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), - send_response_func=partial(send_llm_response), ): if isinstance(response, dict) and ChatEvent.STATUS in response: yield result[ChatEvent.STATUS] else: - response + if isinstance(response, str): + async for result in send_llm_response(response): + yield result await sync_to_async(save_to_conversation_log)( q, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2585f77d..bfe25fe3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -623,7 +623,6 @@ async def generate_summary_from_files( query_images: List[str] = None, agent: Agent = None, send_status_func: Optional[Callable] = None, - send_response_func: Optional[Callable] = None, ): try: file_object = None @@ -636,11 +635,8 @@ async def generate_summary_from_files( file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) if len(file_object) == 0: - response_log = ( - "Sorry, I couldn't find 
the full text of this file. Please re-upload the document and try again." - ) - async for result in send_response_func(response_log): - yield result + response_log = "Sorry, I couldn't find the full text of this file." + yield response_log return contextual_data = " ".join([file.raw_text for file in file_object]) if not q: @@ -657,13 +653,12 @@ async def generate_summary_from_files( agent=agent, ) response_log = str(response) - async for result in send_response_func(response_log): - yield result + + yield result except Exception as e: response_log = "Error summarizing file. Please try again, or contact support." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) - async for result in send_response_func(response_log): - yield result + yield result async def generate_excalidraw_diagram( diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d1578b5b..1beb1f69 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -143,6 +143,7 @@ async def execute_information_collection( online_results: Dict = dict() code_results: Dict = dict() compiled_references: List[Any] = [] + summarize_files: str = "" inferred_queries: List[Any] = [] this_iteration = InformationCollectionIteration(tool=None, query=query) previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) @@ -271,53 +272,31 @@ async def execute_information_collection( exc_info=True, ) - # TODO: Fix summarize later - # elif this_iteration.data_source == ConversationCommand.Summarize: - # response_log = "" - # agent_has_entries = await EntryAdapters.aagent_has_entries(agent) - # if len(file_filters) == 0 and not agent_has_entries: - # previous_iterations.append( - # InformationCollectionIteration( - # data_source=this_iteration.data_source, - # query=this_iteration.query, - # context="No files selected for summarization.", - # ) - # ) - # elif len(file_filters) > 1 and not agent_has_entries: - # response_log = "Only one file can be selected for summarization." - # previous_iterations.append( - # InformationCollectionIteration( - # data_source=this_iteration.data_source, - # query=this_iteration.query, - # context=response_log, - # ) - # ) - # else: - # async for response in generate_summary_from_files( - # q=query, - # user=user, - # file_filters=file_filters, - # meta_log=conversation_history, - # subscribed=subscribed, - # send_status_func=send_status_func, - # ): - # if isinstance(response, dict) and ChatEvent.STATUS in response: - # yield response[ChatEvent.STATUS] - # else: - # response_log = response # type: ignore - # previous_iterations.append( - # InformationCollectionIteration( - # data_source=this_iteration.data_source, - # query=this_iteration.query, - # context=response_log, - # ) - # ) + elif this_iteration.tool == ConversationCommand.Summarize: + try: + async for result in generate_summary_from_files( + this_iteration.query, + user, + file_filters, + conversation_history, + query_images=query_images, + agent=agent, + send_status_func=send_status_func, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + summarize_files = result # type: ignore + except Exception as e: + logger.error(f"Error generating summary: {e}", exc_info=True) + else: + # No valid tools. This is our exit condition. 
current_iteration = MAX_ITERATIONS current_iteration += 1 - if compiled_references or online_results or code_results: + if compiled_references or online_results or code_results or summarize_files: results_data = f"**Results**:\n" if compiled_references: results_data += f"**Document References**: {compiled_references}\n" @@ -325,6 +304,8 @@ async def execute_information_collection( results_data += f"**Online Results**: {online_results}\n" if code_results: results_data += f"**Code Results**: {code_results}\n" + if summarize_files: + results_data += f"**Summarized Files**: {summarize_files}\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data From a11b5293fbe535fb87019fac6de087550b3de511 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 23 Oct 2024 21:45:17 -0700 Subject: [PATCH 51/88] Add uploaded images to research mode, code slash command, include code references --- .../components/chatMessage/chatMessage.tsx | 3 +++ src/khoj/routers/api_chat.py | 25 ++++++++++--------- src/khoj/routers/helpers.py | 2 ++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 85bcc517..1c7757b4 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -301,6 +301,9 @@ export function TrainOfThought(props: TrainOfThoughtProps) { const iconColor = props.primary ? convertColorToTextClass(props.agentColor) : "text-gray-500"; const icon = chooseIconFromHeader(header, iconColor); let markdownRendered = DOMPurify.sanitize(md.render(props.message)); + + // Remove any header tags from markdownRendered + markdownRendered = markdownRendered.replace(//g, ""); return (
Conver return ConversationCommand.Summarize elif query.startswith("/diagram"): return ConversationCommand.Diagram + elif query.startswith("/code"): + return ConversationCommand.Code # If no relevant notes found for the given query elif not any_references: return ConversationCommand.General From 0f3927e810ae178e662092dcd76389209f65310e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 26 Oct 2024 05:59:10 -0700 Subject: [PATCH 52/88] Send gathered references to client after code results calculated --- src/khoj/routers/api_chat.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e80d215f..d0b78d9a 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -973,18 +973,6 @@ async def chat( ): yield result - ## Send Gathered References - async for result in send_event( - ChatEvent.REFERENCES, - { - "inferredQueries": inferred_queries, - "context": compiled_references, - "onlineContext": online_results, - "codeContext": code_results, - }, - ): - yield result - if pending_research: ## Gather Code Results if ConversationCommand.Code in conversation_commands and pending_research: @@ -1015,6 +1003,18 @@ async def chat( exc_info=True, ) + ## Send Gathered References + async for result in send_event( + ChatEvent.REFERENCES, + { + "inferredQueries": inferred_queries, + "context": compiled_references, + "onlineContext": online_results, + "codeContext": code_results, + }, + ): + yield result + # Generate Output ## Generate Image Output if ConversationCommand.Image in conversation_commands: From 3e97ebf0c79c0a8095f0b66d9067345de88e8a57 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 26 Oct 2024 10:45:42 -0700 Subject: [PATCH 53/88] Unescape special characters in prompt traces for better readability --- src/khoj/processor/conversation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 92192f52..760c422e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -502,6 +502,8 @@ def commit_conversation_trace( # Write files and stage them for filename, content in files_to_commit.items(): file_path = os.path.join(repo_path, filename) + # Unescape special characters in content for better readability + content = content.strip().replace("\\n", "\n").replace("\\t", "\t") with open(file_path, "w", encoding="utf-8") as f: f.write(content) repo.index.add([filename]) From bf96d8194331b75b760cf7d23cdd4a6f3ea9cdbd Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 26 Oct 2024 10:39:45 -0700 Subject: [PATCH 54/88] Format online results as YAML to pass it in more readable form to model Previous passing of online results as json dump in prompts was less readable for humans, and I'm guessing less readable for models (trained on human data) as well? 
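For illustration only (not part of this patch, and the sample payload below is invented), a minimal sketch of the readability difference between the two serializations the commit switches between:

    # Compare how the same gathered context renders as a JSON dump vs a YAML dump.
    # Sample data is made up; the patch serializes compiled references, online and
    # code results with yaml.dump(..., allow_unicode=True, sort_keys=False,
    # default_flow_style=False) before adding them to the research prompt.
    import json
    import yaml

    online_results = {
        "what is khoj": {
            "organic": [
                {"title": "Khoj AI", "link": "https://khoj.dev", "snippet": "Your AI second brain"},
            ],
        },
    }

    # JSON dump: one long, quote-heavy line that is hard to skim inside a prompt.
    print(json.dumps(online_results))

    # YAML dump (as used in this patch): indentation mirrors nesting, far less quoting noise.
    print(yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False))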
--- src/khoj/routers/research.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 8221fd5c..94e19868 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -3,6 +3,7 @@ import logging from datetime import datetime from typing import Any, Callable, Dict, List, Optional +import yaml from fastapi import Request from khoj.database.adapters import ConversationAdapters, EntryAdapters @@ -307,13 +308,13 @@ async def execute_information_collection( if compiled_references or online_results or code_results or summarize_files: results_data = f"**Results**:\n" if compiled_references: - results_data += f"**Document References**: {compiled_references}\n" + results_data += f"**Document References**: {yaml.dump(compiled_references, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: - results_data += f"**Online Results**: {online_results}\n" + results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if code_results: - results_data += f"**Code Results**: {code_results}\n" + results_data += f"**Code Results**: {yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if summarize_files: - results_data += f"**Summarized Files**: {summarize_files}\n" + results_data += f"**Summarized Files**: {yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data
From 3e5b5ec122e4afe0d70ea7e0b6f1d1212c09a894 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 26 Oct 2024 10:41:42 -0700 Subject: [PATCH 55/88] Encourage model to read webpages more often after online search Previously the model would rarely read webpages after searching online. The model needs to read webpages more regularly for deeper research and to stop getting stuck in repetitive online search loops
--- src/khoj/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d2e4dd8f..a565ee76 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -357,7 +357,7 @@ tool_descriptions_for_llm = { function_calling_description_for_llm = { ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", ConversationCommand.Online: "To search the internet for the latest, up-to-date information.", - ConversationCommand.Webpage: "To read a webpage url for detailed information from the internet.", + ConversationCommand.Webpage: "To read a webpage for more detailed research from the internet. Usually used when you have webpage links to refer to.", ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc.
are available.", } From e4285941d142684fae69fec8268c4313fedf333b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 26 Oct 2024 16:00:54 -0700 Subject: [PATCH 56/88] Use the advanced chat model if the user is subscribed --- src/khoj/database/adapters/__init__.py | 26 +++++++++++++++++++++----- src/khoj/routers/api_chat.py | 2 +- src/khoj/routers/helpers.py | 2 +- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 4490b7d2..a2c531f8 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1009,8 +1009,15 @@ class ConversationAdapters: """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" # Get the server chat settings server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is not None and server_chat_settings.chat_default is not None: - return server_chat_settings.chat_default + + is_subscribed = is_user_subscribed(user) if user else False + if server_chat_settings: + # If the user is subscribed and the advanced model is enabled, return the advanced model + if is_subscribed and server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + # If the default model is set, return it + if server_chat_settings.chat_default: + return server_chat_settings.chat_default # Get the user's chat settings, if the server chat settings are not set user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None @@ -1026,11 +1033,20 @@ class ConversationAdapters: # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() - .prefetch_related("chat_default", "chat_default__openai_config") + .prefetch_related( + "chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config" + ) .afirst() ) - if server_chat_settings is not None and server_chat_settings.chat_default is not None: - return server_chat_settings.chat_default + is_subscribed = await ais_user_subscribed(user) if user else False + + if server_chat_settings: + # If the user is subscribed and the advanced model is enabled, return the advanced model + if is_subscribed and server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + # If the default model is set, return it + if server_chat_settings.chat_default: + return server_chat_settings.chat_default # Get the user's chat settings, if the server chat settings are not set user_chat_settings = ( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d0b78d9a..5ebfd911 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -817,7 +817,7 @@ async def chat( if not q: conversation_config = await ConversationAdapters.aget_user_conversation_config(user) if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() + conversation_config = await ConversationAdapters.aget_default_conversation_config(user) model_type = conversation_config.model_type formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) async for result in send_llm_response(formatted_help): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0f5a7006..2af1f64d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -140,7 +140,7 @@ def validate_conversation_config(user: 
KhojUser): async def is_ready_to_chat(user: KhojUser): user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user) if user_conversation_config == None: - user_conversation_config = await ConversationAdapters.aget_default_conversation_config() + user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user) if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: chat_model = user_conversation_config.chat_model From 9e8ac7f89e7316077351b69bb7a0aebfee86826c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 26 Oct 2024 16:37:58 -0700 Subject: [PATCH 57/88] Fix input/output mismatches in the /summarize command --- src/khoj/database/adapters/__init__.py | 4 +++- src/khoj/routers/api_chat.py | 3 ++- src/khoj/routers/helpers.py | 7 +++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index a2c531f8..14b092d7 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1480,7 +1480,9 @@ class EntryAdapters: @staticmethod async def aget_agent_entry_filepaths(agent: Agent): - return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True)) + return await sync_to_async(set)( + Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True) + ) @staticmethod def get_all_filenames_by_source(user: KhojUser, file_source: str): diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5ebfd911..69894a68 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -792,9 +792,10 @@ async def chat( tracer=tracer, ): if isinstance(response, dict) and ChatEvent.STATUS in response: - yield result[ChatEvent.STATUS] + yield response[ChatEvent.STATUS] else: if isinstance(response, str): + response_log = response async for result in send_llm_response(response): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2af1f64d..c648c12b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -650,7 +650,7 @@ async def generate_summary_from_files( if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent) + file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) if len(file_filters) > 0: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) @@ -663,7 +663,7 @@ async def generate_summary_from_files( if not q: q = "Create a general summary of the file" async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"): - yield result + yield {ChatEvent.STATUS: result} response = await extract_relevant_summary( q, @@ -674,9 +674,8 @@ async def generate_summary_from_files( agent=agent, tracer=tracer, ) - response_log = str(response) - yield result + yield str(response) except Exception as e: response_log = "Error summarizing file. Please try again, or contact support." 
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) From a121d67b102bbab0bfb35eb70541387340077edc Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 26 Oct 2024 23:46:15 -0700 Subject: [PATCH 58/88] Persist the train of thought in the conversation history --- .../components/chatHistory/chatHistory.tsx | 135 ++++++++++++------ .../components/chatMessage/chatMessage.tsx | 6 + src/khoj/processor/conversation/utils.py | 10 +- src/khoj/routers/api_chat.py | 14 +- src/khoj/routers/helpers.py | 2 + 5 files changed, 117 insertions(+), 50 deletions(-) diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index 31e6c20d..a373da34 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area"; import { InlineLoading } from "../loading/loading"; -import { Lightbulb, ArrowDown } from "@phosphor-icons/react"; +import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react"; import AgentProfileCard from "../profileCard/profileCard"; import { getIconFromIconName } from "@/app/common/iconUtils"; import { AgentData } from "@/app/agents/page"; import React from "react"; import { useIsMobileWidth } from "@/app/common/utils"; +import { Button } from "@/components/ui/button"; interface ChatResponse { status: string; @@ -40,26 +41,51 @@ interface ChatHistoryProps { customClassName?: string; } -function constructTrainOfThought( - trainOfThought: string[], - lastMessage: boolean, - agentColor: string, - key: string, - completed: boolean = false, -) { - const lastIndex = trainOfThought.length - 1; - return ( -
- {!completed && } +interface TrainOfThoughtComponentProps { + trainOfThought: string[]; + lastMessage: boolean; + agentColor: string; + key: string; + completed?: boolean; +} - {trainOfThought.map((train, index) => ( - - ))} +function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) { + const lastIndex = props.trainOfThought.length - 1; + const [collapsed, setCollapsed] = useState(props.completed); + + return ( +
+ {!props.completed && } + {collapsed ? ( + + ) : ( + + )} + + {!collapsed && + props.trainOfThought.map((train, index) => ( + + ))}
); } @@ -265,25 +291,39 @@ export default function ChatHistory(props: ChatHistoryProps) { {data && data.chat && data.chat.map((chatMessage, index) => ( - + <> + {chatMessage.trainOfThought && chatMessage.by === "khoj" && ( + train.data, + )} + lastMessage={false} + agentColor={data?.agent?.color || "orange"} + key={`${index}trainOfThought`} + completed={true} + /> + )} + + ))} {props.incomingMessages && props.incomingMessages.map((message, index) => { @@ -305,14 +345,15 @@ export default function ChatHistory(props: ChatHistoryProps) { customClassName="fullHistory" borderLeftColor={`${data?.agent?.color}-500`} /> - {message.trainOfThought && - constructTrainOfThought( - message.trainOfThought, - index === incompleteIncomingMessageIndex, - data?.agent?.color || "orange", - `${index}trainOfThought`, - message.completed, - )} + {message.trainOfThought && ( + + )} Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1137,6 +1138,7 @@ def generate_chat_response( conversation_id=conversation_id, query_images=query_images, tracer=tracer, + train_of_thought=train_of_thought, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) From 0bd78791ca91dd3402d5754a1bf5aaad1a080ab8 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 27 Oct 2024 15:01:49 -0700 Subject: [PATCH 59/88] Let user exit from command mode with esc, click out, etc. --- .../app/components/chatInputArea/chatInputArea.tsx | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx index 7f2baf1d..7a40b97f 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx @@ -72,6 +72,8 @@ export const ChatInputArea = forwardRef((pr const [progressValue, setProgressValue] = useState(0); const [isDragAndDropping, setIsDragAndDropping] = useState(false); + const [showCommandList, setShowCommandList] = useState(false); + const chatInputRef = ref as React.MutableRefObject; useEffect(() => { if (!uploading) { @@ -275,6 +277,12 @@ export const ChatInputArea = forwardRef((pr chatInputRef.current.style.height = "auto"; chatInputRef.current.style.height = Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px"; + + if (message.startsWith("/") && message.split(" ").length === 1) { + setShowCommandList(true); + } else { + setShowCommandList(false); + } }, [message]); function handleDragOver(event: React.DragEvent) { @@ -360,9 +368,9 @@ export const ChatInputArea = forwardRef((pr )} - {message.startsWith("/") && message.split(" ").length === 1 && ( + {showCommandList && (
- + e.preventDefault()} From 101ea6efb1c79741a9174ea4d84cff7ed7b9aa65 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 27 Oct 2024 15:47:44 -0700 Subject: [PATCH 60/88] Add research mode as a slash command, remove from default path --- src/khoj/routers/api_chat.py | 39 ++++++++++++++++++------------------ src/khoj/routers/helpers.py | 2 ++ src/khoj/utils/helpers.py | 4 ++++ 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 3caf09aa..54d9576a 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -702,7 +702,26 @@ async def chat( inferred_queries: List[Any] = [] defiltered_query = defilter_query(q) - if conversation_commands == [ConversationCommand.Default]: + if conversation_commands == [ConversationCommand.Default] or is_automated_task: + conversation_commands = await aget_relevant_information_sources( + q, + meta_log, + is_automated_task, + user=user, + query_images=uploaded_images, + agent=agent, + tracer=tracer, + ) + + mode = await aget_relevant_output_modes( + q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer + ) + async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): + yield result + if mode not in conversation_commands: + conversation_commands.append(mode) + + if conversation_commands == [ConversationCommand.Research]: async for research_result in execute_information_collection( request=request, user=user, @@ -738,24 +757,6 @@ async def chat( pending_research = False - conversation_commands = await aget_relevant_information_sources( - q, - meta_log, - is_automated_task, - user=user, - query_images=uploaded_images, - agent=agent, - tracer=tracer, - ) - - mode = await aget_relevant_output_modes( - q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer - ) - async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): - yield result - if mode not in conversation_commands: - conversation_commands.append(mode) - for cmd in conversation_commands: await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) q = q.replace(f"/{cmd.value}", "").strip() diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e8ed3e5e..e6437a7d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -234,6 +234,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.Diagram elif query.startswith("/code"): return ConversationCommand.Code + elif query.startswith("/research"): + return ConversationCommand.Research # If no relevant notes found for the given query elif not any_references: return ConversationCommand.General diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index a565ee76..bdbafbad 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -320,6 +320,7 @@ class ConversationCommand(str, Enum): AutomatedTask = "automated_task" Summarize = "summarize" Diagram = "diagram" + Research = "research" command_descriptions = { @@ -334,6 +335,7 @@ command_descriptions = { ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation", ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.", ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.", + ConversationCommand.Research: 
"Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.", } command_descriptions_for_agent = { @@ -342,6 +344,7 @@ command_descriptions_for_agent = { ConversationCommand.Online: "Agent can search the internet for information.", ConversationCommand.Webpage: "Agent can read suggested web pages for information.", ConversationCommand.Summarize: "Agent can read an entire document. Agents knowledge base must be a single document.", + ConversationCommand.Research: "Agent can do deep research on a topic.", } tool_descriptions_for_llm = { @@ -352,6 +355,7 @@ tool_descriptions_for_llm = { ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.", ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.", ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.", + ConversationCommand.Research: "To use when you need to do DEEP research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.", } function_calling_description_for_llm = { From 68499e253bf44e6afa2c53517e8d1ff9918d516b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 27 Oct 2024 15:48:13 -0700 Subject: [PATCH 61/88] Auto-collapse train of thought, show after chat response in history --- .../components/chatHistory/chatHistory.tsx | 69 ++++++++++--------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index a373da34..870ec19c 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -54,29 +54,32 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) { const [collapsed, setCollapsed] = useState(props.completed); return ( -
+
{!props.completed && } - {collapsed ? ( - - ) : ( - - )} - + {props.completed && + (collapsed ? ( + + ) : ( + + ))} {!collapsed && props.trainOfThought.map((train, index) => ( ( <> - {chatMessage.trainOfThought && chatMessage.by === "khoj" && ( - train.data, - )} - lastMessage={false} - agentColor={data?.agent?.color || "orange"} - key={`${index}trainOfThought`} - completed={true} - /> - )} + {chatMessage.trainOfThought && chatMessage.by === "khoj" && ( + train.data, + )} + lastMessage={false} + agentColor={data?.agent?.color || "orange"} + key={`${index}trainOfThought`} + completed={true} + /> + )} ))} {props.incomingMessages && From 2924909692971d4342dfa3a5cef3367105b15861 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 27 Oct 2024 16:37:40 -0700 Subject: [PATCH 62/88] Add a research mode toggle to the chat input area --- .../components/chatHistory/chatHistory.tsx | 2 +- .../chatInputArea/chatInputArea.tsx | 279 ++++++++++-------- 2 files changed, 164 insertions(+), 117 deletions(-) diff --git a/src/interface/web/app/components/chatHistory/chatHistory.tsx b/src/interface/web/app/components/chatHistory/chatHistory.tsx index 870ec19c..6aa61ff6 100644 --- a/src/interface/web/app/components/chatHistory/chatHistory.tsx +++ b/src/interface/web/app/components/chatHistory/chatHistory.tsx @@ -419,7 +419,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
[Hunk bodies not recoverable: the chatHistory.tsx change adjusts the scroll-to-bottom control, and the chatInputArea.tsx hunks rework the chat input markup, including the uploaded-image previews and the research mode toggle.]
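The JSX bodies of the collapsible train-of-thought hunks above are only partially visible. A minimal sketch of the behavior they describe (steps render expanded while a response is streaming, and a show/hide toggle appears once it completes), using plain elements in place of the project's Button and styled containers; the "Sketch" names are illustrative, not the actual component:

import { useState } from "react";

interface TrainOfThoughtSketchProps {
    trainOfThought: string[];
    lastMessage: boolean;
    agentColor: string;
    completed?: boolean;
}

// Completed responses start with the train of thought collapsed; an in-progress
// response keeps it expanded so the reasoning steps stream in visibly.
export function TrainOfThoughtSketch(props: TrainOfThoughtSketchProps) {
    const [collapsed, setCollapsed] = useState<boolean>(!!props.completed);
    const lastIndex = props.trainOfThought.length - 1;

    return (
        <div>
            {/* The expand/collapse toggle only appears once the response has completed. */}
            {props.completed && (
                <button onClick={() => setCollapsed(!collapsed)}>
                    {collapsed ? "Show train of thought" : "Hide train of thought"}
                </button>
            )}
            {/* While collapsed, the individual reasoning steps are hidden entirely. */}
            {!collapsed &&
                props.trainOfThought.map((train, index) => (
                    <p
                        key={index}
                        style={{
                            color: props.agentColor,
                            fontWeight: index === lastIndex && props.lastMessage ? "bold" : "normal",
                        }}
                    >
                        {train}
                    </p>
                ))}
        </div>
    );
}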
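Patch 59 moves the slash-command list behind explicit state so it can be dismissed without sending a message. A small sketch of that state logic as a standalone hook, assuming callers wire dismiss() to an Escape keydown or a click-outside handler; the actual component drives a popover from the project's UI kit:

import { useEffect, useState } from "react";

// Visibility of the slash-command list, kept in state (rather than derived inline
// from the message) so Escape or a click outside the popover can dismiss it.
export function useCommandListVisibility(message: string) {
    const [showCommandList, setShowCommandList] = useState(false);

    useEffect(() => {
        // Show the list only while the user is typing a bare slash command, e.g. "/sum".
        if (message.startsWith("/") && message.split(" ").length === 1) {
            setShowCommandList(true);
        } else {
            setShowCommandList(false);
        }
    }, [message]);

    // Callers can wire this to onKeyDown (Escape) or a click-outside handler.
    const dismiss = () => setShowCommandList(false);

    return { showCommandList, dismiss };
}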
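Patch 60 adds /research as a slash command on the server (get_conversation_command maps it to ConversationCommand.Research), and patch 62 exposes a research mode toggle in the chat input. One plausible, purely hypothetical way to wire the toggle on the client is to prefix outgoing queries with that slash command; the names below (useResearchModeToggle, prepareQuery) are illustrative and not taken from the patch:

import { useState } from "react";

// Hypothetical client-side wiring for the research mode toggle: when enabled, the
// outgoing query is prefixed with the "/research" slash command that the server maps
// to ConversationCommand.Research.
export function useResearchModeToggle() {
    const [researchMode, setResearchMode] = useState(false);

    function prepareQuery(message: string): string {
        // Leave explicit slash commands untouched; otherwise opt the query into research mode.
        if (researchMode && !message.startsWith("/")) {
            return `/research ${message}`;
        }
        return message;
    }

    return { researchMode, setResearchMode, prepareQuery };
}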