mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Add support for relaying attached files through backend calls to models
This commit is contained in:
parent
a0480d5f6c
commit
de73cbc610
8 changed files with 136 additions and 62 deletions
|
@ -37,6 +37,7 @@ def extract_questions_anthropic(
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
attached_files: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -84,7 +85,12 @@ def extract_questions_anthropic(
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
messages = []
|
||||||
|
|
||||||
|
if attached_files:
|
||||||
|
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||||
|
|
||||||
|
messages.append(ChatMessage(content=prompt, role="user"))
|
||||||
|
|
||||||
response = anthropic_completion_with_backoff(
|
response = anthropic_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -38,6 +38,7 @@ def extract_questions_gemini(
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
attached_files: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -85,7 +86,13 @@ def extract_questions_gemini(
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
messages = []
|
||||||
|
|
||||||
|
if attached_files:
|
||||||
|
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||||
|
|
||||||
|
messages.append(ChatMessage(content=prompt, role="user"))
|
||||||
|
messages.append(ChatMessage(content=system_prompt, role="system"))
|
||||||
|
|
||||||
response = gemini_send_message_to_model(
|
response = gemini_send_message_to_model(
|
||||||
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
|
||||||
|
|
|
@ -35,6 +35,7 @@ def extract_questions(
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
|
attached_files: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -81,7 +82,12 @@ def extract_questions(
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
messages = []
|
||||||
|
|
||||||
|
if attached_files:
|
||||||
|
messages.append(ChatMessage(content=attached_files, role="user"))
|
||||||
|
|
||||||
|
messages.append(ChatMessage(content=prompt, role="user"))
|
||||||
|
|
||||||
response = send_message_to_model(
|
response = send_message_to_model(
|
||||||
messages,
|
messages,
|
||||||
|
|
|
@ -36,6 +36,7 @@ from khoj.utils.helpers import (
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
merge_dicts,
|
merge_dicts,
|
||||||
)
|
)
|
||||||
|
from khoj.utils.rawconfig import FileAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -137,25 +138,6 @@ def construct_iteration_history(
|
||||||
return previous_iterations_history
|
return previous_iterations_history
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
|
||||||
chat_history = ""
|
|
||||||
for chat in conversation_history.get("chat", [])[-n:]:
|
|
||||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
|
||||||
|
|
||||||
if chat["intent"].get("inferred-queries"):
|
|
||||||
chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
|
||||||
|
|
||||||
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
|
||||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
|
||||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
|
||||||
return chat_history
|
|
||||||
|
|
||||||
|
|
||||||
def construct_tool_chat_history(
|
def construct_tool_chat_history(
|
||||||
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
|
||||||
) -> Dict[str, list]:
|
) -> Dict[str, list]:
|
||||||
|
@ -241,6 +223,7 @@ def save_to_conversation_log(
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
|
raw_attached_files: List[FileAttachment] = [],
|
||||||
tracer: Dict[str, Any] = {},
|
tracer: Dict[str, Any] = {},
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
):
|
):
|
||||||
|
@ -253,6 +236,7 @@ def save_to_conversation_log(
|
||||||
"created": user_message_time,
|
"created": user_message_time,
|
||||||
"images": query_images,
|
"images": query_images,
|
||||||
"turnId": turn_id,
|
"turnId": turn_id,
|
||||||
|
"attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files],
|
||||||
},
|
},
|
||||||
khoj_message_metadata={
|
khoj_message_metadata={
|
||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
|
@ -306,6 +290,22 @@ def construct_structured_message(message: str, images: list[str], model_type: st
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def gather_raw_attached_files(
|
||||||
|
attached_files: Dict[str, str],
|
||||||
|
):
|
||||||
|
"""_summary_
|
||||||
|
Gather contextual data from the given (raw) files
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(attached_files) == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
contextual_data = " ".join(
|
||||||
|
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
|
||||||
|
)
|
||||||
|
return f"I have attached the following files:\n\n{contextual_data}"
|
||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message,
|
user_message,
|
||||||
system_message=None,
|
system_message=None,
|
||||||
|
@ -335,6 +335,8 @@ def generate_chatml_messages_with_context(
|
||||||
chatml_messages: List[ChatMessage] = []
|
chatml_messages: List[ChatMessage] = []
|
||||||
for chat in conversation_log.get("chat", []):
|
for chat in conversation_log.get("chat", []):
|
||||||
message_context = ""
|
message_context = ""
|
||||||
|
message_attached_files = ""
|
||||||
|
|
||||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||||
message_context += chat.get("intent").get("inferred-queries")[0]
|
message_context += chat.get("intent").get("inferred-queries")[0]
|
||||||
if not is_none_or_empty(chat.get("context")):
|
if not is_none_or_empty(chat.get("context")):
|
||||||
|
@ -343,6 +345,15 @@ def generate_chatml_messages_with_context(
|
||||||
)
|
)
|
||||||
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
|
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
|
||||||
|
|
||||||
|
if chat.get("attachedFiles"):
|
||||||
|
raw_attached_files = chat.get("attachedFiles")
|
||||||
|
attached_files_dict = dict()
|
||||||
|
for file in raw_attached_files:
|
||||||
|
attached_files_dict[file["name"]] = file["content"]
|
||||||
|
|
||||||
|
message_attached_files = gather_raw_attached_files(attached_files_dict)
|
||||||
|
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
|
||||||
|
|
||||||
if not is_none_or_empty(chat.get("onlineContext")):
|
if not is_none_or_empty(chat.get("onlineContext")):
|
||||||
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
||||||
|
|
||||||
|
|
|
@ -6,18 +6,12 @@ import os
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
|
||||||
|
|
||||||
from khoj.database.adapters import ais_user_subscribed
|
from khoj.database.adapters import ais_user_subscribed
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import ChatEvent, clean_code_python, clean_json
|
||||||
ChatEvent,
|
from khoj.routers.helpers import construct_chat_history, send_message_to_model_wrapper
|
||||||
clean_code_python,
|
|
||||||
clean_json,
|
|
||||||
construct_chat_history,
|
|
||||||
)
|
|
||||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,6 @@ from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
clean_json,
|
clean_json,
|
||||||
construct_chat_history,
|
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
)
|
)
|
||||||
|
@ -104,6 +103,7 @@ from khoj.utils.config import OfflineChatProcessorModel
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
LRU,
|
LRU,
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
get_file_type,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_valid_url,
|
is_valid_url,
|
||||||
log_telemetry,
|
log_telemetry,
|
||||||
|
@ -111,7 +111,7 @@ from khoj.utils.helpers import (
|
||||||
timer,
|
timer,
|
||||||
tool_descriptions_for_llm,
|
tool_descriptions_for_llm,
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -167,6 +167,12 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_content(file: UploadFile):
|
||||||
|
file_content = file.file.read()
|
||||||
|
file_type, encoding = get_file_type(file.content_type, file_content)
|
||||||
|
return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)
|
||||||
|
|
||||||
|
|
||||||
def update_telemetry_state(
|
def update_telemetry_state(
|
||||||
request: Request,
|
request: Request,
|
||||||
telemetry_type: str,
|
telemetry_type: str,
|
||||||
|
@ -248,23 +254,49 @@ async def agenerate_chat_response(*args):
|
||||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||||
|
|
||||||
|
|
||||||
async def gather_attached_files(
|
def gather_raw_attached_files(
|
||||||
user: KhojUser,
|
attached_files: Dict[str, str],
|
||||||
file_filters: List[str],
|
):
|
||||||
) -> str:
|
"""_summary_
|
||||||
|
Gather contextual data from the given (raw) files
|
||||||
"""
|
"""
|
||||||
Gather contextual data from the given files
|
|
||||||
"""
|
if len(attached_files) == 0:
|
||||||
if len(file_filters) == 0:
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
|
contextual_data = " ".join(
|
||||||
|
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
|
||||||
|
)
|
||||||
|
return f"I have attached the following files:\n\n{contextual_data}"
|
||||||
|
|
||||||
if len(file_objects) == 0:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
|
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||||
return contextual_data
|
chat_history = ""
|
||||||
|
for chat in conversation_history.get("chat", [])[-n:]:
|
||||||
|
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
|
||||||
|
if chat["intent"].get("inferred-queries"):
|
||||||
|
chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
|
||||||
|
|
||||||
|
chat_history += f"{agent_name}: {chat['message']}\n\n"
|
||||||
|
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||||
|
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
elif chat["by"] == "you":
|
||||||
|
raw_attached_files = chat.get("attachedFiles")
|
||||||
|
if raw_attached_files:
|
||||||
|
attached_files: Dict[str, str] = {}
|
||||||
|
for file in raw_attached_files:
|
||||||
|
attached_files[file["name"]] = file["content"]
|
||||||
|
|
||||||
|
attached_file_context = gather_raw_attached_files(attached_files)
|
||||||
|
chat_history += f"User: {attached_file_context}\n"
|
||||||
|
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||||
|
@ -1179,6 +1211,7 @@ def generate_chat_response(
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
attached_files: str = None,
|
attached_files: str = None,
|
||||||
|
raw_attached_files: List[FileAttachment] = None,
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response = None
|
||||||
|
@ -1204,6 +1237,7 @@ def generate_chat_response(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
raw_attached_files=raw_attached_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
|
@ -1299,6 +1333,7 @@ def generate_chat_response(
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
query_images=query_images,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
attached_files=attached_files,
|
attached_files=attached_files,
|
||||||
|
@ -1313,23 +1348,6 @@ def generate_chat_response(
|
||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
class ChatRequestBody(BaseModel):
|
|
||||||
q: str
|
|
||||||
n: Optional[int] = 7
|
|
||||||
d: Optional[float] = None
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
title: Optional[str] = None
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
turn_id: Optional[str] = None
|
|
||||||
city: Optional[str] = None
|
|
||||||
region: Optional[str] = None
|
|
||||||
country: Optional[str] = None
|
|
||||||
country_code: Optional[str] = None
|
|
||||||
timezone: Optional[str] = None
|
|
||||||
images: Optional[list[str]] = None
|
|
||||||
create_new: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteMessageRequestBody(BaseModel):
|
class DeleteMessageRequestBody(BaseModel):
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
turn_id: str
|
turn_id: str
|
||||||
|
|
|
@ -20,7 +20,6 @@ from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
extract_relevant_info,
|
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
send_message_to_model_wrapper,
|
send_message_to_model_wrapper,
|
||||||
)
|
)
|
||||||
|
@ -187,6 +186,7 @@ async def execute_information_collection(
|
||||||
query_images,
|
query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
|
attached_files=attached_files,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
|
|
@ -138,6 +138,38 @@ class SearchResponse(ConfigBase):
|
||||||
corpus_id: str
|
corpus_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class FileData(BaseModel):
|
||||||
|
name: str
|
||||||
|
content: bytes
|
||||||
|
file_type: str
|
||||||
|
encoding: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileAttachment(BaseModel):
|
||||||
|
name: str
|
||||||
|
content: str
|
||||||
|
file_type: str
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequestBody(BaseModel):
|
||||||
|
q: str
|
||||||
|
n: Optional[int] = 7
|
||||||
|
d: Optional[float] = None
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
title: Optional[str] = None
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
turn_id: Optional[str] = None
|
||||||
|
city: Optional[str] = None
|
||||||
|
region: Optional[str] = None
|
||||||
|
country: Optional[str] = None
|
||||||
|
country_code: Optional[str] = None
|
||||||
|
timezone: Optional[str] = None
|
||||||
|
images: Optional[list[str]] = None
|
||||||
|
files: Optional[list[FileAttachment]] = None
|
||||||
|
create_new: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class Entry:
|
class Entry:
|
||||||
raw: str
|
raw: str
|
||||||
compiled: str
|
compiled: str
|
||||||
|
|
Loading…
Add table
Reference in a new issue