mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add support for relaying attached files through backend calls to models
This commit is contained in:
parent
a0480d5f6c
commit
de73cbc610
8 changed files with 136 additions and 62 deletions
|
@ -37,6 +37,7 @@ def extract_questions_anthropic(
|
|||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
attached_files: str = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -84,7 +85,12 @@ def extract_questions_anthropic(
|
|||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
messages = []
|
||||
|
||||
if attached_files:
|
||||
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
|
||||
response = anthropic_completion_with_backoff(
|
||||
messages=messages,
|
||||
|
|
|
@ -38,6 +38,7 @@ def extract_questions_gemini(
|
|||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
attached_files: str = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -85,7 +86,13 @@ def extract_questions_gemini(
|
|||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||
messages = []
|
||||
|
||||
if attached_files:
|
||||
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||
|
|
|
@ -35,6 +35,7 @@ def extract_questions(
|
|||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
tracer: dict = {},
|
||||
attached_files: str = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -81,7 +82,12 @@ def extract_questions(
|
|||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
messages = []
|
||||
|
||||
if attached_files:
|
||||
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||
|
||||
messages.append(ChatMessage(content=prompt, role="user"))
|
||||
|
||||
response = send_message_to_model(
|
||||
messages,
|
||||
|
|
|
@ -36,6 +36,7 @@ from khoj.utils.helpers import (
|
|||
is_none_or_empty,
|
||||
merge_dicts,
|
||||
)
|
||||
from khoj.utils.rawconfig import FileAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -137,25 +138,6 @@ def construct_iteration_history(
|
|||
return previous_iterations_history
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
|
||||
if chat["intent"].get("inferred-queries"):
|
||||
chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||
return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
|
||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
||||
) -> Dict[str, list]:
|
||||
|
@ -241,6 +223,7 @@ def save_to_conversation_log(
|
|||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
query_images: List[str] = None,
|
||||
raw_attached_files: List[FileAttachment] = [],
|
||||
tracer: Dict[str, Any] = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
):
|
||||
|
@ -253,6 +236,7 @@ def save_to_conversation_log(
|
|||
"created": user_message_time,
|
||||
"images": query_images,
|
||||
"turnId": turn_id,
|
||||
"attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files],
|
||||
},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
|
@ -306,6 +290,22 @@ def construct_structured_message(message: str, images: list[str], model_type: st
|
|||
return message
|
||||
|
||||
|
||||
def gather_raw_attached_files(
|
||||
attached_files: Dict[str, str],
|
||||
):
|
||||
"""_summary_
|
||||
Gather contextual data from the given (raw) files
|
||||
"""
|
||||
|
||||
if len(attached_files) == 0:
|
||||
return ""
|
||||
|
||||
contextual_data = " ".join(
|
||||
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
|
||||
)
|
||||
return f"I have attached the following files:\n\n{contextual_data}"
|
||||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message,
|
||||
system_message=None,
|
||||
|
@ -335,6 +335,8 @@ def generate_chatml_messages_with_context(
|
|||
chatml_messages: List[ChatMessage] = []
|
||||
for chat in conversation_log.get("chat", []):
|
||||
message_context = ""
|
||||
message_attached_files = ""
|
||||
|
||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||
message_context += chat.get("intent").get("inferred-queries")[0]
|
||||
if not is_none_or_empty(chat.get("context")):
|
||||
|
@ -343,6 +345,15 @@ def generate_chatml_messages_with_context(
|
|||
)
|
||||
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
|
||||
|
||||
if chat.get("attachedFiles"):
|
||||
raw_attached_files = chat.get("attachedFiles")
|
||||
attached_files_dict = dict()
|
||||
for file in raw_attached_files:
|
||||
attached_files_dict[file["name"]] = file["content"]
|
||||
|
||||
message_attached_files = gather_raw_attached_files(attached_files_dict)
|
||||
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
|
||||
|
||||
if not is_none_or_empty(chat.get("onlineContext")):
|
||||
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
||||
|
||||
|
|
|
@ -6,18 +6,12 @@ import os
|
|||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from khoj.database.adapters import ais_user_subscribed
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
clean_code_python,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
)
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.processor.conversation.utils import ChatEvent, clean_code_python, clean_json
|
||||
from khoj.routers.helpers import construct_chat_history, send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
|
|
|
@ -91,7 +91,6 @@ from khoj.processor.conversation.utils import (
|
|||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
generate_chatml_messages_with_context,
|
||||
save_to_conversation_log,
|
||||
)
|
||||
|
@ -104,6 +103,7 @@ from khoj.utils.config import OfflineChatProcessorModel
|
|||
from khoj.utils.helpers import (
|
||||
LRU,
|
||||
ConversationCommand,
|
||||
get_file_type,
|
||||
is_none_or_empty,
|
||||
is_valid_url,
|
||||
log_telemetry,
|
||||
|
@ -111,7 +111,7 @@ from khoj.utils.helpers import (
|
|||
timer,
|
||||
tool_descriptions_for_llm,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -167,6 +167,12 @@ async def is_ready_to_chat(user: KhojUser):
|
|||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||
|
||||
|
||||
def get_file_content(file: UploadFile):
|
||||
file_content = file.file.read()
|
||||
file_type, encoding = get_file_type(file.content_type, file_content)
|
||||
return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)
|
||||
|
||||
|
||||
def update_telemetry_state(
|
||||
request: Request,
|
||||
telemetry_type: str,
|
||||
|
@ -248,23 +254,49 @@ async def agenerate_chat_response(*args):
|
|||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||
|
||||
|
||||
async def gather_attached_files(
|
||||
user: KhojUser,
|
||||
file_filters: List[str],
|
||||
) -> str:
|
||||
def gather_raw_attached_files(
|
||||
attached_files: Dict[str, str],
|
||||
):
|
||||
"""_summary_
|
||||
Gather contextual data from the given (raw) files
|
||||
"""
|
||||
Gather contextual data from the given files
|
||||
"""
|
||||
if len(file_filters) == 0:
|
||||
|
||||
if len(attached_files) == 0:
|
||||
return ""
|
||||
|
||||
file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
|
||||
contextual_data = " ".join(
|
||||
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
|
||||
)
|
||||
return f"I have attached the following files:\n\n{contextual_data}"
|
||||
|
||||
if len(file_objects) == 0:
|
||||
return ""
|
||||
|
||||
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
|
||||
return contextual_data
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
|
||||
if chat["intent"].get("inferred-queries"):
|
||||
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||
|
||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||
elif chat["by"] == "you":
|
||||
raw_attached_files = chat.get("attachedFiles")
|
||||
if raw_attached_files:
|
||||
attached_files: Dict[str, str] = {}
|
||||
for file in raw_attached_files:
|
||||
attached_files[file["name"]] = file["content"]
|
||||
|
||||
attached_file_context = gather_raw_attached_files(attached_files)
|
||||
chat_history += f"User: {attached_file_context}\n"
|
||||
|
||||
return chat_history
|
||||
|
||||
|
||||
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||
|
@ -1179,6 +1211,7 @@ def generate_chat_response(
|
|||
tracer: dict = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
attached_files: str = None,
|
||||
raw_attached_files: List[FileAttachment] = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1204,6 +1237,7 @@ def generate_chat_response(
|
|||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_attached_files=raw_attached_files,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
@ -1299,6 +1333,7 @@ def generate_chat_response(
|
|||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
query_images=query_images,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
attached_files=attached_files,
|
||||
|
@ -1313,23 +1348,6 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
q: str
|
||||
n: Optional[int] = 7
|
||||
d: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
title: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
turn_id: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
images: Optional[list[str]] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
class DeleteMessageRequestBody(BaseModel):
|
||||
conversation_id: str
|
||||
turn_id: str
|
||||
|
|
|
@ -20,7 +20,6 @@ from khoj.routers.api import extract_references_and_questions
|
|||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
construct_chat_history,
|
||||
extract_relevant_info,
|
||||
generate_summary_from_files,
|
||||
send_message_to_model_wrapper,
|
||||
)
|
||||
|
@ -187,6 +186,7 @@ async def execute_information_collection(
|
|||
query_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
attached_files=attached_files,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
|
|
@ -138,6 +138,38 @@ class SearchResponse(ConfigBase):
|
|||
corpus_id: str
|
||||
|
||||
|
||||
class FileData(BaseModel):
|
||||
name: str
|
||||
content: bytes
|
||||
file_type: str
|
||||
encoding: str | None = None
|
||||
|
||||
|
||||
class FileAttachment(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
file_type: str
|
||||
size: int
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
q: str
|
||||
n: Optional[int] = 7
|
||||
d: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
title: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
turn_id: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
images: Optional[list[str]] = None
|
||||
files: Optional[list[FileAttachment]] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
class Entry:
|
||||
raw: str
|
||||
compiled: str
|
||||
|
|
Loading…
Reference in a new issue