Add support for relaying attached files through backend calls to models

Author: sabaimran
Date: 2024-11-07 15:58:52 -08:00
Parent: a0480d5f6c
Commit: de73cbc610

8 changed files with 136 additions and 62 deletions

View file

@@ -37,6 +37,7 @@ def extract_questions_anthropic(
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -84,7 +85,12 @@ def extract_questions_anthropic(
         vision_enabled=vision_enabled,
     )

-    messages = [ChatMessage(content=prompt, role="user")]
+    messages = []
+    if attached_files:
+        messages.append(ChatMessage(content=attached_files, role="user"))
+
+    messages.append(ChatMessage(content=prompt, role="user"))
+
     response = anthropic_completion_with_backoff(
         messages=messages,

View file

@@ -38,6 +38,7 @@ def extract_questions_gemini(
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -85,7 +86,13 @@ def extract_questions_gemini(
         vision_enabled=vision_enabled,
     )

-    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
+    messages = []
+    if attached_files:
+        messages.append(ChatMessage(content=attached_files, role="user"))
+
+    messages.append(ChatMessage(content=prompt, role="user"))
+    messages.append(ChatMessage(content=system_prompt, role="system"))
+
     response = gemini_send_message_to_model(
         messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer

View file

@@ -35,6 +35,7 @@ def extract_questions(
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -81,7 +82,12 @@ def extract_questions(
        vision_enabled=vision_enabled,
    )

-    messages = [ChatMessage(content=prompt, role="user")]
+    messages = []
+    if attached_files:
+        messages.append(ChatMessage(content=attached_files, role="user"))
+
+    messages.append(ChatMessage(content=prompt, role="user"))
+
     response = send_message_to_model(
         messages,
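
All three provider modules above (Anthropic, Gemini, OpenAI) make the same change: when the caller supplies attached file content, it is injected as its own user message ahead of the query-extraction prompt. Below is a minimal sketch of that shared ordering, assuming the langchain `ChatMessage` class these modules already use; the helper name `build_extraction_messages` is hypothetical, not part of the commit.

```python
from typing import List, Optional

from langchain.schema import ChatMessage


def build_extraction_messages(
    prompt: str,
    attached_files: Optional[str] = None,
    system_prompt: Optional[str] = None,
) -> List[ChatMessage]:
    messages = []
    # Attached file content leads as its own user turn, so the model
    # reads the files before the query-extraction instructions.
    if attached_files:
        messages.append(ChatMessage(content=attached_files, role="user"))
    messages.append(ChatMessage(content=prompt, role="user"))
    # Only the Gemini variant appends a trailing system message here;
    # the other providers handle the system prompt separately.
    if system_prompt:
        messages.append(ChatMessage(content=system_prompt, role="system"))
    return messages
```

Keeping the attachments in a separate leading message leaves the extraction prompt template itself untouched across all three providers.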

View file

@@ -36,6 +36,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     merge_dicts,
 )
+from khoj.utils.rawconfig import FileAttachment

 logger = logging.getLogger(__name__)

@@ -137,25 +138,6 @@ def construct_iteration_history(
     return previous_iterations_history


-def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
-    chat_history = ""
-    for chat in conversation_history.get("chat", [])[-n:]:
-        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
-            chat_history += f"User: {chat['intent']['query']}\n"
-            if chat["intent"].get("inferred-queries"):
-                chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
-            chat_history += f"{agent_name}: {chat['message']}\n\n"
-        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
-            chat_history += f"User: {chat['intent']['query']}\n"
-            chat_history += f"{agent_name}: [generated image redacted for space]\n"
-        elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
-            chat_history += f"User: {chat['intent']['query']}\n"
-            chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
-    return chat_history


 def construct_tool_chat_history(
     previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
 ) -> Dict[str, list]:

@@ -241,6 +223,7 @@ def save_to_conversation_log(
     conversation_id: str = None,
     automation_id: str = None,
     query_images: List[str] = None,
+    raw_attached_files: List[FileAttachment] = [],
     tracer: Dict[str, Any] = {},
     train_of_thought: List[Any] = [],
 ):

@@ -253,6 +236,7 @@ def save_to_conversation_log(
             "created": user_message_time,
             "images": query_images,
             "turnId": turn_id,
+            "attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files],
         },
         khoj_message_metadata={
             "context": compiled_references,

@@ -306,6 +290,22 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     return message


+def gather_raw_attached_files(
+    attached_files: Dict[str, str],
+):
+    """
+    Gather contextual data from the given (raw) files
+    """
+    if len(attached_files) == 0:
+        return ""
+
+    contextual_data = " ".join(
+        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
+    )
+    return f"I have attached the following files:\n\n{contextual_data}"


 def generate_chatml_messages_with_context(
     user_message,
     system_message=None,

@@ -335,6 +335,8 @@ def generate_chatml_messages_with_context(
     chatml_messages: List[ChatMessage] = []
     for chat in conversation_log.get("chat", []):
         message_context = ""
+        message_attached_files = ""
+
         if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
             message_context += chat.get("intent").get("inferred-queries")[0]
         if not is_none_or_empty(chat.get("context")):

@@ -343,6 +345,15 @@ def generate_chatml_messages_with_context(
             )
             message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"

+        if chat.get("attachedFiles"):
+            raw_attached_files = chat.get("attachedFiles")
+            attached_files_dict = dict()
+            for file in raw_attached_files:
+                attached_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_attached_files(attached_files_dict)
+            chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+
         if not is_none_or_empty(chat.get("onlineContext")):
             message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"

View file

@@ -6,18 +6,12 @@ import os
 from typing import Any, Callable, List, Optional

 import aiohttp
-import requests

 from khoj.database.adapters import ais_user_subscribed
 from khoj.database.models import Agent, KhojUser
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import (
-    ChatEvent,
-    clean_code_python,
-    clean_json,
-    construct_chat_history,
-)
-from khoj.routers.helpers import send_message_to_model_wrapper
+from khoj.processor.conversation.utils import ChatEvent, clean_code_python, clean_json
+from khoj.routers.helpers import construct_chat_history, send_message_to_model_wrapper
 from khoj.utils.helpers import timer
 from khoj.utils.rawconfig import LocationData

View file

@@ -91,7 +91,6 @@ from khoj.processor.conversation.utils import (
     ChatEvent,
     ThreadedGenerator,
     clean_json,
-    construct_chat_history,
     generate_chatml_messages_with_context,
     save_to_conversation_log,
 )

@@ -104,6 +103,7 @@ from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import (
     LRU,
     ConversationCommand,
+    get_file_type,
     is_none_or_empty,
     is_valid_url,
     log_telemetry,

@@ -111,7 +111,7 @@ from khoj.utils.helpers import (
     timer,
     tool_descriptions_for_llm,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData

 logger = logging.getLogger(__name__)

@@ -167,6 +167,12 @@ async def is_ready_to_chat(user: KhojUser):
     raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")


+def get_file_content(file: UploadFile):
+    file_content = file.file.read()
+    file_type, encoding = get_file_type(file.content_type, file_content)
+    return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)


 def update_telemetry_state(
     request: Request,
     telemetry_type: str,

@@ -248,23 +254,49 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)


-async def gather_attached_files(
-    user: KhojUser,
-    file_filters: List[str],
-) -> str:
-    """
-    Gather contextual data from the given files
-    """
-    if len(file_filters) == 0:
-        return ""
-
-    file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
-
-    if len(file_objects) == 0:
-        return ""
-
-    contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
-    return contextual_data
+def gather_raw_attached_files(
+    attached_files: Dict[str, str],
+):
+    """
+    Gather contextual data from the given (raw) files
+    """
+    if len(attached_files) == 0:
+        return ""
+
+    contextual_data = " ".join(
+        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
+    )
+    return f"I have attached the following files:\n\n{contextual_data}"
+
+
+def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
+    chat_history = ""
+    for chat in conversation_history.get("chat", [])[-n:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
+            chat_history += f"User: {chat['intent']['query']}\n"
+            if chat["intent"].get("inferred-queries"):
+                chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
+            chat_history += f"{agent_name}: {chat['message']}\n\n"
+        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: [generated image redacted for space]\n"
+        elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
+        elif chat["by"] == "you":
+            raw_attached_files = chat.get("attachedFiles")
+            if raw_attached_files:
+                attached_files: Dict[str, str] = {}
+                for file in raw_attached_files:
+                    attached_files[file["name"]] = file["content"]
+
+                attached_file_context = gather_raw_attached_files(attached_files)
+                chat_history += f"User: {attached_file_context}\n"
+
+    return chat_history


 async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:

@@ -1179,6 +1211,7 @@ def generate_chat_response(
     tracer: dict = {},
     train_of_thought: List[Any] = [],
     attached_files: str = None,
+    raw_attached_files: List[FileAttachment] = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None

@@ -1204,6 +1237,7 @@ def generate_chat_response(
             query_images=query_images,
             tracer=tracer,
             train_of_thought=train_of_thought,
+            raw_attached_files=raw_attached_files,
         )

     conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)

@@ -1299,6 +1333,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                query_images=query_images,
                 vision_available=vision_available,
                 tracer=tracer,
                 attached_files=attached_files,

@@ -1313,23 +1348,6 @@ def generate_chat_response(
     return chat_response, metadata


-class ChatRequestBody(BaseModel):
-    q: str
-    n: Optional[int] = 7
-    d: Optional[float] = None
-    stream: Optional[bool] = False
-    title: Optional[str] = None
-    conversation_id: Optional[str] = None
-    turn_id: Optional[str] = None
-    city: Optional[str] = None
-    region: Optional[str] = None
-    country: Optional[str] = None
-    country_code: Optional[str] = None
-    timezone: Optional[str] = None
-    images: Optional[list[str]] = None
-    create_new: Optional[bool] = False


 class DeleteMessageRequestBody(BaseModel):
     conversation_id: str
     turn_id: str
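
The new `get_file_content` helper above turns a FastAPI `UploadFile` into a `FileData` record. Below is a hedged sketch of how it might sit behind an upload route; the `/chat/upload` path, the inlined `FileData` model, and the hardcoded type sniffing are illustrative stand-ins (khoj's actual `get_file_type` helper inspects the MIME type and raw bytes instead).

```python
from fastapi import FastAPI, UploadFile
from pydantic import BaseModel

app = FastAPI()


class FileData(BaseModel):
    name: str
    content: bytes
    file_type: str
    encoding: str | None = None


def get_file_content(file: UploadFile) -> FileData:
    # Read the raw bytes off the spooled upload, as in the diff above.
    file_content = file.file.read()
    file_type, encoding = "plaintext", "utf-8"  # stand-in for get_file_type()
    return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)


@app.post("/chat/upload")  # hypothetical route, for illustration only
async def upload_attachment(file: UploadFile) -> dict:
    file_data = get_file_content(file)
    return {"name": file_data.name, "file_type": file_data.file_type, "size": len(file_data.content)}
```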

View file

@@ -20,7 +20,6 @@ from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ChatEvent,
     construct_chat_history,
-    extract_relevant_info,
     generate_summary_from_files,
     send_message_to_model_wrapper,
 )

@@ -187,6 +186,7 @@ async def execute_information_collection(
             query_images,
             agent=agent,
             tracer=tracer,
+            attached_files=attached_files,
         ):
             if isinstance(result, dict) and ChatEvent.STATUS in result:
                 yield result[ChatEvent.STATUS]

View file

@@ -138,6 +138,38 @@ class SearchResponse(ConfigBase):
     corpus_id: str


+class FileData(BaseModel):
+    name: str
+    content: bytes
+    file_type: str
+    encoding: str | None = None
+
+
+class FileAttachment(BaseModel):
+    name: str
+    content: str
+    file_type: str
+    size: int
+
+
+class ChatRequestBody(BaseModel):
+    q: str
+    n: Optional[int] = 7
+    d: Optional[float] = None
+    stream: Optional[bool] = False
+    title: Optional[str] = None
+    conversation_id: Optional[str] = None
+    turn_id: Optional[str] = None
+    city: Optional[str] = None
+    region: Optional[str] = None
+    country: Optional[str] = None
+    country_code: Optional[str] = None
+    timezone: Optional[str] = None
+    images: Optional[list[str]] = None
+    files: Optional[list[FileAttachment]] = None
+    create_new: Optional[bool] = False


 class Entry:
     raw: str
     compiled: str
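
With `ChatRequestBody` now living alongside the new models, a chat request can carry attachments inline via the new `files` field. A quick usage sketch, assuming the models above are importable from `khoj.utils.rawconfig` as the helpers import change in this commit indicates; the field values are made up.

```python
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment

request = ChatRequestBody(
    q="What do my attached notes say about Q3 goals?",
    files=[
        FileAttachment(
            name="notes.md",
            content="# Q3 planning\n- Ship attachment support\n",
            file_type="markdown",
            size=42,
        )
    ],
)
# Pydantic v2 serialization (the commit itself uses model_dump(mode="json")),
# so attachments travel as plain JSON in the request body.
print(request.model_dump_json(exclude_none=True))
```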