Simplify storing and displaying document results in research mode

- Mention count of notes and files discovered
- Store the query associated with each compiled reference retrieved for
  easier referencing
Debanjum 2024-10-28 20:26:59 -07:00
parent 8bd94bf855
commit d3184ae39a
3 changed files with 23 additions and 20 deletions
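
Concretely, each compiled reference retrieved from the user's documents now carries the query that produced it alongside the compiled entry and source file, e.g. (hypothetical values):

    {"query": "q4 budget", "compiled": "## Q4 Budget\n...", "file": "work.org"}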

View file

@@ -24,6 +24,7 @@ from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
+from khoj.search_filter.base_filter import BaseFilter
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
@@ -409,7 +410,8 @@ def remove_json_codeblock(response: str):
 def defilter_query(query: str):
     """Remove any query filters in query"""
     defiltered_query = query
-    for filter in [DateFilter(), WordFilter(), FileFilter()]:
+    filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
+    for filter in filters:
         defiltered_query = filter.defilter(defiltered_query)
     return defiltered_query
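
For context, a self-contained sketch of this defilter pass. The Filter classes below are simplified stand-ins, not Khoj's real implementations, and the regexes only approximate each filter's query syntax:

    import re


    class WordFilter:
        def defilter(self, query: str) -> str:
            # Strip include/exclude word operators like +"term" or -"term".
            return re.sub(r'[+-]"[^"]+"', "", query)


    class FileFilter:
        def defilter(self, query: str) -> str:
            # Strip file scoping operators like file:"work.org".
            return re.sub(r'file:"[^"]+"', "", query)


    class DateFilter:
        def defilter(self, query: str) -> str:
            # Strip date operators like dt>="2024-10-01".
            return re.sub(r'dt[<>=]+"[^"]+"', "", query)


    def defilter_query(query: str) -> str:
        """Remove any query filters in query."""
        defiltered_query = query
        for f in [WordFilter(), FileFilter(), DateFilter()]:
            defiltered_query = f.defilter(defiltered_query)
        return defiltered_query.strip()


    print(defilter_query('budget notes +"q4" file:"work.org" dt>="2024-10-01"'))
    # -> budget notes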

View file

@@ -356,7 +356,7 @@ async def extract_references_and_questions(
     user = request.user.object if request.user.is_authenticated else None
     # Initialize Variables
-    compiled_references: List[Any] = []
+    compiled_references: List[dict[str, str]] = []
     inferred_queries: List[str] = []
     agent_has_entries = False
@@ -501,7 +501,8 @@ async def extract_references_and_questions(
         )
         search_results = text_search.deduplicated_search_responses(search_results)
         compiled_references = [
-            {"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
+            {"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
+            for q, item in zip(inferred_queries, search_results)
         ]

    yield compiled_references, inferred_queries, defiltered_query
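
A toy illustration of the new packing, with hypothetical values; in the real code each search result exposes its text via item.additional rather than a plain dict:

    inferred_queries = ["q4 budget", "travel plans"]
    # Stand-ins for deduplicated search responses; real items carry
    # item.additional["compiled"] and item.additional["file"].
    search_results = [
        {"compiled": "## Q4 Budget\ndetails", "file": "work.org"},
        {"compiled": "## Travel\nitinerary", "file": "personal.org"},
    ]

    compiled_references = [
        {"query": q, "compiled": r["compiled"], "file": r["file"]}
        for q, r in zip(inferred_queries, search_results)
    ]
    # Each reference now records which inferred query retrieved it.

Note that zip pairs element-wise and stops at the shorter sequence, so the pairing assumes deduplication yields one result per query, in query order.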

View file

@@ -146,7 +146,7 @@ async def execute_information_collection(
     while current_iteration < MAX_ITERATIONS:
         online_results: Dict = dict()
         code_results: Dict = dict()
-        compiled_references: List[Any] = []
+        document_results: List[Dict[str, str]] = []
         summarize_files: str = ""
         inferred_queries: List[Any] = []
         this_iteration = InformationCollectionIteration(tool=None, query=query)
@@ -171,8 +171,8 @@
                 this_iteration = result

         if this_iteration.tool == ConversationCommand.Notes:
             ## Extract Document References
-            compiled_references, inferred_queries, defiltered_query = [], [], None
+            this_iteration.context = []
+            document_results = []
             async for result in extract_references_and_questions(
                 request,
                 conversation_history,
@@ -189,22 +189,22 @@
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
-                else:
-                    compiled_references.extend(result[0])
-                    inferred_queries.extend(result[1])
-                    defiltered_query = result[2]
-                    this_iteration.context = compiled_references
+                elif isinstance(result, tuple):
+                    document_results = result[0]
+                    this_iteration.context += document_results

-            if not is_none_or_empty(compiled_references):
+            if not is_none_or_empty(document_results):
                 try:
-                    headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
+                    distinct_files = {d["file"] for d in document_results}
+                    distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
                     # Strip only leading # from headings
-                    headings = headings.replace("#", "")
-                    async for result in send_status_func(f"**Found Relevant Notes**: {headings}"):
+                    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
+                    async for result in send_status_func(
+                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
+                    ):
                         yield result
                 except Exception as e:
-                    # TODO Get correct type for compiled across research notes extraction
-                    logger.error(f"Error extracting references: {e}", exc_info=True)
+                    logger.error(f"Error extracting document references: {e}", exc_info=True)

         elif this_iteration.tool == ConversationCommand.Online:
             async for result in search_online(
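
To see what the new status line produces, a small sketch with hypothetical document_results shaped like the dicts built above:

    # Hypothetical references in the new {query, compiled, file} shape.
    document_results = [
        {"query": "q4 budget", "compiled": "## Q4 Budget\ndetails", "file": "work.org"},
        {"query": "q4 budget", "compiled": "## Budget Review\nnotes", "file": "work.org"},
        {"query": "travel plans", "compiled": "## Travel\nitinerary", "file": "personal.org"},
    ]

    distinct_files = {d["file"] for d in document_results}
    distinct_headings = {d["compiled"].split("\n")[0] for d in document_results if "compiled" in d}
    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")

    # Prints "**Found 3 Notes Across 2 Files**:" followed by the heading list.
    print(f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}")

As in the diff, replace("#", "") strips every "#" in the joined headings, not only the leading markers the inherited comment describes.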
@@ -305,10 +305,10 @@
         current_iteration += 1

-        if compiled_references or online_results or code_results or summarize_files:
+        if document_results or online_results or code_results or summarize_files:
             results_data = f"**Results**:\n"
-            if compiled_references:
-                results_data += f"**Document References**: {yaml.dump(compiled_references, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+            if document_results:
+                results_data += f"**Document References**: {yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
             if online_results:
                 results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
             if code_results:
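
For reference, a minimal sketch of how one stored document reference serializes with these yaml.dump options (assuming PyYAML; values hypothetical):

    import yaml

    document_results = [{"query": "q4 budget", "compiled": "## Q4 Budget", "file": "work.org"}]

    print(yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False))
    # - query: q4 budget
    #   compiled: '## Q4 Budget'
    #   file: work.org

With sort_keys=False the query key stays first, so each reference in the research log leads with the query that retrieved it.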