mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Add summarization capability with agent knowledge base
This commit is contained in:
parent
df3dc33e96
commit
f700d5bddb
2 changed files with 26 additions and 7 deletions
|
@ -1214,8 +1214,8 @@ class FileObjectAdapters:
|
||||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
|
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
|
||||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))
|
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def async_get_all_file_objects(user: KhojUser):
|
async def async_get_all_file_objects(user: KhojUser):
|
||||||
|
@ -1305,6 +1305,10 @@ class EntryAdapters:
|
||||||
async def auser_has_entries(user: KhojUser):
|
async def auser_has_entries(user: KhojUser):
|
||||||
return await Entry.objects.filter(user=user).aexists()
|
return await Entry.objects.filter(user=user).aexists()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aagent_has_entries(agent: Agent):
|
||||||
|
return await Entry.objects.filter(agent=agent).aexists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
async def adelete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||||
|
@ -1319,6 +1323,10 @@ class EntryAdapters:
|
||||||
|
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_agent_entry_filepaths(agent: Agent):
|
||||||
|
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -722,19 +722,30 @@ async def chat(
|
||||||
conversation_commands.remove(ConversationCommand.Summarize)
|
conversation_commands.remove(ConversationCommand.Summarize)
|
||||||
elif ConversationCommand.Summarize in conversation_commands:
|
elif ConversationCommand.Summarize in conversation_commands:
|
||||||
response_log = ""
|
response_log = ""
|
||||||
if len(file_filters) == 0:
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
||||||
|
if len(file_filters) == 0 and not agent_has_entries:
|
||||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
elif len(file_filters) > 1:
|
elif len(file_filters) > 1 and not agent_has_entries:
|
||||||
response_log = "Only one file can be selected for summarization."
|
response_log = "Only one file can be selected for summarization."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
file_object = None
|
||||||
|
if await EntryAdapters.aagent_has_entries(agent):
|
||||||
|
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
||||||
|
if len(file_names) > 0:
|
||||||
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
|
||||||
|
None, file_names[0], agent
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(file_filters) > 0:
|
||||||
|
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||||
|
|
||||||
if len(file_object) == 0:
|
if len(file_object) == 0:
|
||||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
@ -753,7 +764,7 @@ async def chat(
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_log = "Error summarizing file."
|
response_log = "Error summarizing file. Please try again, or contact support."
|
||||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
yield result
|
yield result
|
||||||
|
|
Loading…
Reference in a new issue