Add methods for reading full files by name and including context

Now that models have much larger context windows, we can reasonably include full texts of certain files in the messages. Do this when an explicit file filter is set in a conversation. Do so in a separate user message in order to mitigate any confusion in the operation.

Pipe the relevant attached_files context through all methods calling into models.

We'll want to limit the file sizes for which this is used and provide more helpful UI indicators that this sort of behavior is taking place.
This commit is contained in:
sabaimran 2024-11-04 16:37:13 -08:00
parent e3ca52b7cb
commit 362bdebd02
12 changed files with 142 additions and 13 deletions

View file

@ -1387,6 +1387,10 @@ class FileObjectAdapters:
async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
@staticmethod
async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
@staticmethod @staticmethod
async def async_get_all_file_objects(user: KhojUser): async def async_get_all_file_objects(user: KhojUser):
return await sync_to_async(list)(FileObject.objects.filter(user=user)) return await sync_to_async(list)(FileObject.objects.filter(user=user))

View file

@ -147,6 +147,7 @@ def converse_anthropic(
query_images: Optional[list[str]] = None, query_images: Optional[list[str]] = None,
vision_available: bool = False, vision_available: bool = False,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
""" """
Converse with user using Anthropic's Claude Converse with user using Anthropic's Claude
@ -203,6 +204,7 @@ def converse_anthropic(
query_images=query_images, query_images=query_images,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC, model_type=ChatModelOptions.ModelType.ANTHROPIC,
attached_files=attached_files,
) )
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@ -108,7 +108,14 @@ def extract_questions_gemini(
def gemini_send_message_to_model( def gemini_send_message_to_model(
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={} messages,
api_key,
model,
response_type="text",
temperature=0,
model_kwargs=None,
tracer={},
attached_files: str = None,
): ):
""" """
Send message to model Send message to model
@ -152,6 +159,7 @@ def converse_gemini(
query_images: Optional[list[str]] = None, query_images: Optional[list[str]] = None,
vision_available: bool = False, vision_available: bool = False,
tracer={}, tracer={},
attached_files: str = None,
): ):
""" """
Converse with user using Google's Gemini Converse with user using Google's Gemini
@ -209,6 +217,7 @@ def converse_gemini(
query_images=query_images, query_images=query_images,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE, model_type=ChatModelOptions.ModelType.GOOGLE,
attached_files=attached_files,
) )
messages, system_prompt = format_messages_for_gemini(messages, system_prompt) messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@ -38,6 +38,7 @@ def extract_questions_offline(
temperature: float = 0.7, temperature: float = 0.7,
personality_context: Optional[str] = None, personality_context: Optional[str] = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> List[str]: ) -> List[str]:
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
@ -87,6 +88,7 @@ def extract_questions_offline(
loaded_model=offline_chat_model, loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE, model_type=ChatModelOptions.ModelType.OFFLINE,
attached_files=attached_files,
) )
state.chat_lock.acquire() state.chat_lock.acquire()
@ -153,6 +155,7 @@ def converse_offline(
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> Union[ThreadedGenerator, Iterator[str]]: ) -> Union[ThreadedGenerator, Iterator[str]]:
""" """
Converse with user using Llama Converse with user using Llama
@ -216,6 +219,7 @@ def converse_offline(
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE, model_type=ChatModelOptions.ModelType.OFFLINE,
attached_files=attached_files,
) )
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})

View file

@ -149,6 +149,7 @@ def converse(
query_images: Optional[list[str]] = None, query_images: Optional[list[str]] = None,
vision_available: bool = False, vision_available: bool = False,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
""" """
Converse with user using OpenAI's ChatGPT Converse with user using OpenAI's ChatGPT
@ -206,6 +207,7 @@ def converse(
query_images=query_images, query_images=query_images,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI, model_type=ChatModelOptions.ModelType.OPENAI,
attached_files=attached_files,
) )
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}") logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View file

@ -318,6 +318,7 @@ def generate_chatml_messages_with_context(
vision_enabled=False, vision_enabled=False,
model_type="", model_type="",
context_message="", context_message="",
attached_files: str = None,
): ):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model""" """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs # Set max prompt size from user config or based on pre-configured for model and machine specs
@ -341,8 +342,10 @@ def generate_chatml_messages_with_context(
{f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []} {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
) )
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n" message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
if not is_none_or_empty(chat.get("onlineContext")): if not is_none_or_empty(chat.get("onlineContext")):
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
if not is_none_or_empty(message_context): if not is_none_or_empty(message_context):
reconstructed_context_message = ChatMessage(content=message_context, role="user") reconstructed_context_message = ChatMessage(content=message_context, role="user")
chatml_messages.insert(0, reconstructed_context_message) chatml_messages.insert(0, reconstructed_context_message)
@ -366,8 +369,13 @@ def generate_chatml_messages_with_context(
) )
if not is_none_or_empty(context_message): if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user")) messages.append(ChatMessage(content=context_message, role="user"))
if not is_none_or_empty(attached_files):
messages.append(ChatMessage(content=attached_files, role="user"))
if len(chatml_messages) > 0: if len(chatml_messages) > 0:
messages += chatml_messages messages += chatml_messages
if not is_none_or_empty(system_message): if not is_none_or_empty(system_message):
messages.append(ChatMessage(content=system_message, role="system")) messages.append(ChatMessage(content=system_message, role="system"))

View file

@ -29,6 +29,7 @@ async def text_to_image(
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
status_code = 200 status_code = 200
image = None image = None
@ -70,6 +71,7 @@ async def text_to_image(
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
if send_status_func: if send_status_func:

View file

@ -68,6 +68,7 @@ async def search_online(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
query += " ".join(custom_filters) query += " ".join(custom_filters)
if not is_internet_connected(): if not is_internet_connected():
@ -77,7 +78,14 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries( subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer query,
conversation_history,
location,
user,
query_images=query_images,
agent=agent,
tracer=tracer,
attached_files=attached_files,
) )
response_dict = {} response_dict = {}
@ -159,11 +167,19 @@ async def read_webpages(
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
attached_files: str = None,
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
urls = await infer_webpage_urls( urls = await infer_webpage_urls(
query, conversation_history, location, user, query_images, agent=agent, tracer=tracer query,
conversation_history,
location,
user,
query_images,
agent=agent,
tracer=tracer,
attached_files=attached_files,
) )
# Get the top 10 web pages to read # Get the top 10 web pages to read

View file

@ -6,6 +6,7 @@ import os
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional
import aiohttp import aiohttp
import requests
from khoj.database.adapters import ais_user_subscribed from khoj.database.adapters import ais_user_subscribed
from khoj.database.models import Agent, KhojUser from khoj.database.models import Agent, KhojUser
@ -37,6 +38,7 @@ async def run_code(
agent: Agent = None, agent: Agent = None,
sandbox_url: str = SANDBOX_URL, sandbox_url: str = SANDBOX_URL,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
# Generate Code # Generate Code
if send_status_func: if send_status_func:
@ -53,6 +55,7 @@ async def run_code(
query_images, query_images,
agent, agent,
tracer, tracer,
attached_files,
) )
except Exception as e: except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}") raise ValueError(f"Failed to generate code for {query} with error: {e}")
@ -82,6 +85,7 @@ async def generate_python_code(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> List[str]: ) -> List[str]:
location = f"{location_data}" if location_data else "Unknown" location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@ -109,6 +113,7 @@ async def generate_python_code(
response_type="json_object", response_type="json_object",
user=user, user=user,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list

View file

@ -19,7 +19,6 @@ from khoj.database.adapters import (
AgentAdapters, AgentAdapters,
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters, PublicConversationAdapters,
aget_user_name, aget_user_name,
) )
@ -46,7 +45,7 @@ from khoj.routers.helpers import (
aget_relevant_output_modes, aget_relevant_output_modes,
construct_automation_created_message, construct_automation_created_message,
create_automation, create_automation,
extract_relevant_info, gather_attached_files,
generate_excalidraw_diagram, generate_excalidraw_diagram,
generate_summary_from_files, generate_summary_from_files,
get_conversation_command, get_conversation_command,
@ -707,6 +706,8 @@ async def chat(
## Extract Document References ## Extract Document References
compiled_references: List[Any] = [] compiled_references: List[Any] = []
inferred_queries: List[Any] = [] inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = await gather_attached_files(user, file_filters)
if conversation_commands == [ConversationCommand.Default] or is_automated_task: if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources( conversation_commands = await aget_relevant_information_sources(
@ -717,6 +718,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
) )
# If we're doing research, we don't want to do anything else # If we're doing research, we don't want to do anything else
@ -757,6 +759,7 @@ async def chat(
location=location, location=location,
file_filters=conversation.file_filters if conversation else [], file_filters=conversation.file_filters if conversation else [],
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(research_result, InformationCollectionIteration): if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult: if research_result.summarizedResult:
@ -812,6 +815,7 @@ async def chat(
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(response, dict) and ChatEvent.STATUS in response: if isinstance(response, dict) and ChatEvent.STATUS in response:
yield response[ChatEvent.STATUS] yield response[ChatEvent.STATUS]
@ -945,6 +949,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -970,6 +975,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1010,6 +1016,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1049,6 +1056,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1110,6 +1118,7 @@ async def chat(
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer, tracer=tracer,
attached_files=attached_file_context,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1166,6 +1175,7 @@ async def chat(
uploaded_images, uploaded_images,
tracer, tracer,
train_of_thought, train_of_thought,
attached_file_context,
) )
# Send Response # Send Response

View file

@ -248,6 +248,25 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args) return await loop.run_in_executor(executor, generate_chat_response, *args)
async def gather_attached_files(
user: KhojUser,
file_filters: List[str],
) -> str:
"""
Gather contextual data from the given files
"""
if len(file_filters) == 0:
return ""
file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
if len(file_objects) == 0:
return ""
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
return contextual_data
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
""" """
Create a title from the given query Create a title from the given query
@ -294,6 +313,7 @@ async def aget_relevant_information_sources(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -331,6 +351,7 @@ async def aget_relevant_information_sources(
response_type="json_object", response_type="json_object",
user=user, user=user,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
try: try:
@ -440,6 +461,7 @@ async def infer_webpage_urls(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> List[str]: ) -> List[str]:
""" """
Infer webpage links from the given query Infer webpage links from the given query
@ -469,6 +491,7 @@ async def infer_webpage_urls(
response_type="json_object", response_type="json_object",
user=user, user=user,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
# Validate that the response is a non-empty, JSON-serializable list of URLs # Validate that the response is a non-empty, JSON-serializable list of URLs
@ -494,6 +517,7 @@ async def generate_online_subqueries(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> List[str]: ) -> List[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
@ -523,6 +547,7 @@ async def generate_online_subqueries(
response_type="json_object", response_type="json_object",
user=user, user=user,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@ -645,6 +670,7 @@ async def generate_summary_from_files(
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
try: try:
file_object = None file_object = None
@ -653,17 +679,28 @@ async def generate_summary_from_files(
if len(file_names) > 0: if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
if len(file_filters) > 0: if len(file_object) == 0 and not attached_files:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file." response_log = "Sorry, I couldn't find the full text of this file."
yield response_log yield response_log
return return
contextual_data = " ".join([file.raw_text for file in file_object])
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_object])
if attached_files:
contextual_data += f"\n\n{attached_files}"
if not q: if not q:
q = "Create a general summary of the file" q = "Create a general summary of the file"
async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
file_names = [file.file_name for file in file_object]
file_names.extend(file_filters)
all_file_names = ""
for file_name in file_names:
all_file_names += f"- {file_name}\n"
async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
yield {ChatEvent.STATUS: result} yield {ChatEvent.STATUS: result}
response = await extract_relevant_summary( response = await extract_relevant_summary(
@ -694,6 +731,7 @@ async def generate_excalidraw_diagram(
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
if send_status_func: if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"): async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@ -709,6 +747,7 @@ async def generate_excalidraw_diagram(
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
if send_status_func: if send_status_func:
@ -735,6 +774,7 @@ async def generate_better_diagram_description(
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
) -> str: ) -> str:
""" """
Generate a diagram description from the given query and context Generate a diagram description from the given query and context
@ -772,7 +812,11 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger): with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer improve_diagram_description_prompt,
query_images=query_images,
user=user,
tracer=tracer,
attached_files=attached_files,
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -820,6 +864,7 @@ async def generate_better_image_prompt(
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = "",
) -> str: ) -> str:
""" """
Generate a better image prompt from the given query Generate a better image prompt from the given query
@ -867,7 +912,7 @@ async def generate_better_image_prompt(
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer image_prompt, query_images=query_images, user=user, tracer=tracer, attached_files=attached_files
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -884,6 +929,7 @@ async def send_message_to_model_wrapper(
query_images: List[str] = None, query_images: List[str] = None,
context: str = "", context: str = "",
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
@ -922,6 +968,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return send_message_to_model_offline( return send_message_to_model_offline(
@ -948,6 +995,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available, vision_enabled=vision_available,
query_images=query_images, query_images=query_images,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return send_message_to_model( return send_message_to_model(
@ -970,6 +1018,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available, vision_enabled=vision_available,
query_images=query_images, query_images=query_images,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return anthropic_send_message_to_model( return anthropic_send_message_to_model(
@ -991,6 +1040,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available, vision_enabled=vision_available,
query_images=query_images, query_images=query_images,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
@ -1006,6 +1056,7 @@ def send_message_to_model_wrapper_sync(
response_type: str = "text", response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = "",
): ):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -1029,6 +1080,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return send_message_to_model_offline( return send_message_to_model_offline(
@ -1050,6 +1102,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
@ -1071,6 +1124,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return anthropic_send_message_to_model( return anthropic_send_message_to_model(
@ -1090,6 +1144,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
attached_files=attached_files,
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
@ -1121,6 +1176,7 @@ def generate_chat_response(
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
tracer: dict = {}, tracer: dict = {},
train_of_thought: List[Any] = [], train_of_thought: List[Any] = [],
attached_files: str = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
@ -1173,6 +1229,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1198,6 +1255,7 @@ def generate_chat_response(
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -1220,6 +1278,7 @@ def generate_chat_response(
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -1240,6 +1299,7 @@ def generate_chat_response(
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
metadata.update({"chat_model": conversation_config.chat_model}) metadata.update({"chat_model": conversation_config.chat_model})

View file

@ -47,6 +47,7 @@ async def apick_next_tool(
max_iterations: int = 5, max_iterations: int = 5,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
@ -95,6 +96,7 @@ async def apick_next_tool(
user=user, user=user,
query_images=query_images, query_images=query_images,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
) )
try: try:
@ -137,6 +139,7 @@ async def execute_information_collection(
location: LocationData = None, location: LocationData = None,
file_filters: List[str] = [], file_filters: List[str] = [],
tracer: dict = {}, tracer: dict = {},
attached_files: str = None,
): ):
current_iteration = 0 current_iteration = 0
MAX_ITERATIONS = 5 MAX_ITERATIONS = 5
@ -161,6 +164,7 @@ async def execute_information_collection(
MAX_ITERATIONS, MAX_ITERATIONS,
send_status_func, send_status_func,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -233,6 +237,7 @@ async def execute_information_collection(
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -264,6 +269,7 @@ async def execute_information_collection(
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
attached_files=attached_files,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -288,6 +294,7 @@ async def execute_information_collection(
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
send_status_func=send_status_func, send_status_func=send_status_func,
attached_files=attached_files,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]