Add summarization capability with agent knowledge base

This commit is contained in:
sabaimran 2024-10-07 21:20:23 -07:00
parent df3dc33e96
commit f700d5bddb
2 changed files with 26 additions and 7 deletions

View file

@ -1214,8 +1214,8 @@ class FileObjectAdapters:
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod @staticmethod
async def async_get_file_objects_by_name(user: KhojUser, file_name: str): async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod @staticmethod
async def async_get_all_file_objects(user: KhojUser): async def async_get_all_file_objects(user: KhojUser):
@ -1305,6 +1305,10 @@ class EntryAdapters:
async def auser_has_entries(user: KhojUser): async def auser_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists() return await Entry.objects.filter(user=user).aexists()
@staticmethod
async def aagent_has_entries(agent: Agent):
return await Entry.objects.filter(agent=agent).aexists()
@staticmethod @staticmethod
async def adelete_entry_by_file(user: KhojUser, file_path: str): async def adelete_entry_by_file(user: KhojUser, file_path: str):
return await Entry.objects.filter(user=user, file_path=file_path).adelete() return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@ -1319,6 +1323,10 @@ class EntryAdapters:
return deleted_count return deleted_count
@staticmethod
async def aget_agent_entry_filepaths(agent: Agent):
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
@staticmethod @staticmethod
def get_all_filenames_by_source(user: KhojUser, file_source: str): def get_all_filenames_by_source(user: KhojUser, file_source: str):
return ( return (

View file

@ -722,19 +722,30 @@ async def chat(
conversation_commands.remove(ConversationCommand.Summarize) conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands: elif ConversationCommand.Summarize in conversation_commands:
response_log = "" response_log = ""
if len(file_filters) == 0: agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
response_log = "No files selected for summarization. Please add files using the section on the left." response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
elif len(file_filters) > 1: elif len(file_filters) > 1 and not agent_has_entries:
response_log = "Only one file can be selected for summarization." response_log = "Only one file can be selected for summarization."
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
else: else:
try: try:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
None, file_names[0], agent
)
if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0: if len(file_object) == 0:
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
return return
@ -753,7 +764,7 @@ async def chat(
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result
except Exception as e: except Exception as e:
response_log = "Error summarizing file." response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
yield result yield result