From d3184ae39ae9c3087542f29dcc0c4c69f67cb1b8 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 28 Oct 2024 20:26:59 -0700 Subject: [PATCH] Simplify storing and displaying document results in research mode - Mention count of notes and files discovered - Store query associated with each compiled reference retrieved for easier referencing --- src/khoj/processor/conversation/utils.py | 4 ++- src/khoj/routers/api.py | 5 ++-- src/khoj/routers/research.py | 34 ++++++++++++------------ 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 4783990d..ea63d81f 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -24,6 +24,7 @@ from khoj.database.adapters import ConversationAdapters, ais_user_subscribed from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens +from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter @@ -409,7 +410,8 @@ def remove_json_codeblock(response: str): def defilter_query(query: str): """Remove any query filters in query""" defiltered_query = query - for filter in [DateFilter(), WordFilter(), FileFilter()]: + filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()] + for filter in filters: defiltered_query = filter.defilter(defiltered_query) return defiltered_query diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 1e3fb092..8f83d916 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -356,7 +356,7 @@ async def extract_references_and_questions( user = request.user.object if request.user.is_authenticated else None # Initialize Variables - 
compiled_references: List[Any] = [] + compiled_references: List[dict[str, str]] = [] inferred_queries: List[str] = [] agent_has_entries = False @@ -501,7 +501,8 @@ async def extract_references_and_questions( ) search_results = text_search.deduplicated_search_responses(search_results) compiled_references = [ - {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results + {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]} + for q, item in zip(inferred_queries, search_results) ] yield compiled_references, inferred_queries, defiltered_query diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index dcc86c7a..ca5963e3 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -146,7 +146,7 @@ async def execute_information_collection( while current_iteration < MAX_ITERATIONS: online_results: Dict = dict() code_results: Dict = dict() - compiled_references: List[Any] = [] + document_results: List[Dict[str, str]] = [] summarize_files: str = "" inferred_queries: List[Any] = [] this_iteration = InformationCollectionIteration(tool=None, query=query) @@ -171,8 +171,8 @@ async def execute_information_collection( this_iteration = result if this_iteration.tool == ConversationCommand.Notes: - ## Extract Document References - compiled_references, inferred_queries, defiltered_query = [], [], None + this_iteration.context = [] + document_results = [] async for result in extract_references_and_questions( request, conversation_history, @@ -189,22 +189,22 @@ async def execute_information_collection( ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] - else: - compiled_references.extend(result[0]) - inferred_queries.extend(result[1]) - defiltered_query = result[2] - this_iteration.context = compiled_references + elif isinstance(result, tuple): + document_results = result[0] + this_iteration.context += document_results - if not 
is_none_or_empty(compiled_references): + if not is_none_or_empty(document_results): try: - headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references])) + distinct_files = {d["file"] for d in document_results} + distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]) # Strip only leading # from headings - headings = headings.replace("#", "") - async for result in send_status_func(f"**Found Relevant Notes**: {headings}"): + headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") + async for result in send_status_func( + f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" + ): yield result except Exception as e: - # TODO Get correct type for compiled across research notes extraction - logger.error(f"Error extracting references: {e}", exc_info=True) + logger.error(f"Error extracting document references: {e}", exc_info=True) elif this_iteration.tool == ConversationCommand.Online: async for result in search_online( @@ -305,10 +305,10 @@ async def execute_information_collection( current_iteration += 1 - if compiled_references or online_results or code_results or summarize_files: + if document_results or online_results or code_results or summarize_files: results_data = f"**Results**:\n" - if compiled_references: - results_data += f"**Document References**: {yaml.dump(compiled_references, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if document_results: + results_data += f"**Document References**: {yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if code_results: