Rename attached_files to query_files. Update relevant backend and client-side code.

This commit is contained in:
sabaimran 2024-11-11 11:21:26 -08:00
parent 47937d5148
commit 2bb2ff27a4
16 changed files with 159 additions and 146 deletions

View file

@ -83,19 +83,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
}
const storedUploadedFiles = localStorage.getItem("uploadedFiles");
const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
const uploadedFiles: AttachedFileText[] = [];
for (const file of parsedFiles) {
uploadedFiles.push({
name: file.name,
file_type: file.file_type,
content: file.content,
size: file.size,
});
if (storedUploadedFiles) {
const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
const uploadedFiles: AttachedFileText[] = [];
for (const file of parsedFiles) {
uploadedFiles.push({
name: file.name,
file_type: file.file_type,
content: file.content,
size: file.size,
});
}
localStorage.removeItem("uploadedFiles");
props.setUploadedFiles(uploadedFiles);
}
localStorage.removeItem("uploadedFiles");
props.setUploadedFiles(uploadedFiles);
}, [setQueryToProcess, props.setImages, conversationId]);
useEffect(() => {
@ -212,7 +214,7 @@ export default function Chat() {
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
images: images,
attachedFiles: uploadedFiles,
queryFiles: uploadedFiles,
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);

View file

@ -373,7 +373,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
images: message.images,
conversationId: props.conversationId,
turnId: messageTurnId,
attachedFiles: message.attachedFiles,
queryFiles: message.queryFiles,
}}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}

View file

@ -161,7 +161,7 @@ export interface SingleChatMessage {
images?: string[];
conversationId: string;
turnId?: string;
attachedFiles?: AttachedFileText[];
queryFiles?: AttachedFileText[];
}
export interface StreamMessage {
@ -178,7 +178,7 @@ export interface StreamMessage {
intentType?: string;
inferredQueries?: string[];
turnId?: string;
attachedFiles?: AttachedFileText[];
queryFiles?: AttachedFileText[];
}
export interface ChatHistoryData {
@ -708,16 +708,21 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
onMouseLeave={(event) => setIsHovering(false)}
onMouseEnter={(event) => setIsHovering(true)}
>
{props.chatMessage.attachedFiles && props.chatMessage.attachedFiles.length > 0 && (
<div className="flex flex-wrap flex-col m-2">
{props.chatMessage.attachedFiles.map((file, index) => (
{props.chatMessage.queryFiles && props.chatMessage.queryFiles.length > 0 && (
<div className="flex flex-wrap flex-col m-2 max-w-full">
{props.chatMessage.queryFiles.map((file, index) => (
<Dialog key={index}>
<DialogTrigger>
<div className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg m-1 p-2 w-full">
{getIconFromFilename(file.file_type)}
<span className="truncate">{file.name}</span>
<DialogTrigger asChild>
<div
className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg m-1 p-2 w-full
"
>
<div className="flex-shrink-0">
{getIconFromFilename(file.file_type)}
</div>
<span className="truncate flex-1 min-w-0">{file.name}</span>
{file.size && (
<span className="text-gray-400">
<span className="text-gray-400 flex-shrink-0">
({convertBytesToText(file.size)})
</span>
)}

View file

@ -103,9 +103,17 @@ class PdfToEntries(TextToEntries):
pdf_entries_per_file = loader.load()
# Convert the loaded entries into the desired format
pdf_entry_by_pages = [page.page_content for page in pdf_entries_per_file]
pdf_entry_by_pages = [PdfToEntries.clean_text(page.page_content) for page in pdf_entries_per_file]
except Exception as e:
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
return pdf_entry_by_pages
@staticmethod
def clean_text(text: str) -> str:
# Remove null bytes
text = text.replace("\x00", "")
# Replace invalid Unicode
text = text.encode("utf-8", errors="ignore").decode("utf-8")
return text

View file

@ -36,7 +36,7 @@ def extract_questions_anthropic(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -83,7 +83,7 @@ def extract_questions_anthropic(
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
attached_file_context=attached_files,
attached_file_context=query_files,
)
messages = []
@ -152,7 +152,7 @@ def converse_anthropic(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -210,7 +210,7 @@ def converse_anthropic(
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
attached_files=attached_files,
query_files=query_files,
)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@ -37,7 +37,7 @@ def extract_questions_gemini(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -84,7 +84,7 @@ def extract_questions_gemini(
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
vision_enabled=vision_enabled,
attached_file_context=attached_files,
attached_file_context=query_files,
)
messages = []
@ -162,7 +162,7 @@ def converse_gemini(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
attached_files: str = None,
query_files: str = None,
tracer={},
):
"""
@ -221,7 +221,7 @@ def converse_gemini(
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
attached_files=attached_files,
query_files=query_files,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@ -37,7 +37,7 @@ def extract_questions_offline(
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
) -> List[str]:
"""
@ -88,7 +88,7 @@ def extract_questions_offline(
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE,
attached_files=attached_files,
query_files=query_files,
)
state.chat_lock.acquire()
@ -154,7 +154,7 @@ def converse_offline(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
@ -219,7 +219,7 @@ def converse_offline(
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
attached_files=attached_files,
query_files=query_files,
)
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})

View file

@ -34,7 +34,7 @@ def extract_questions(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -80,7 +80,7 @@ def extract_questions(
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
attached_file_context=attached_files,
attached_file_context=query_files,
)
messages = []
@ -151,7 +151,7 @@ def converse(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -210,7 +210,7 @@ def converse(
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
attached_files=attached_files,
query_files=query_files,
)
truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View file

@ -157,14 +157,14 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
elif chat["by"] == "you":
raw_attached_files = chat.get("attachedFiles")
if raw_attached_files:
attached_files: Dict[str, str] = {}
for file in raw_attached_files:
attached_files[file["name"]] = file["content"]
raw_query_files = chat.get("queryFiles")
if raw_query_files:
query_files: Dict[str, str] = {}
for file in raw_query_files:
query_files[file["name"]] = file["content"]
attached_file_context = gather_raw_attached_files(attached_files)
chat_history += f"User: {attached_file_context}\n"
query_file_context = gather_raw_query_files(query_files)
chat_history += f"User: {query_file_context}\n"
return chat_history
@ -254,7 +254,7 @@ def save_to_conversation_log(
conversation_id: str = None,
automation_id: str = None,
query_images: List[str] = None,
raw_attached_files: List[FileAttachment] = [],
raw_query_files: List[FileAttachment] = [],
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
@ -267,7 +267,7 @@ def save_to_conversation_log(
"created": user_message_time,
"images": query_images,
"turnId": turn_id,
"attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files],
"queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
},
khoj_message_metadata={
"context": compiled_references,
@ -330,18 +330,18 @@ def construct_structured_message(
return message
def gather_raw_attached_files(
attached_files: Dict[str, str],
def gather_raw_query_files(
query_files: Dict[str, str],
):
"""
Gather contextual data from the given (raw) files
"""
if len(attached_files) == 0:
if len(query_files) == 0:
return ""
contextual_data = " ".join(
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
)
return f"I have attached the following files:\n\n{contextual_data}"
@ -358,7 +358,7 @@ def generate_chatml_messages_with_context(
vision_enabled=False,
model_type="",
context_message="",
attached_files: str = None,
query_files: str = None,
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs
@ -389,13 +389,13 @@ def generate_chatml_messages_with_context(
)
message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
if chat.get("attachedFiles"):
raw_attached_files = chat.get("attachedFiles")
attached_files_dict = dict()
for file in raw_attached_files:
attached_files_dict[file["name"]] = file["content"]
if chat.get("queryFiles"):
raw_query_files = chat.get("queryFiles")
query_files_dict = dict()
for file in raw_query_files:
query_files_dict[file["name"]] = file["content"]
message_attached_files = gather_raw_attached_files(attached_files_dict)
message_attached_files = gather_raw_query_files(query_files_dict)
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
if not is_none_or_empty(chat.get("onlineContext")):
@ -407,7 +407,7 @@ def generate_chatml_messages_with_context(
role = "user" if chat["by"] == "you" else "assistant"
message_content = construct_structured_message(
chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=attached_files
chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
)
reconstructed_message = ChatMessage(content=message_content, role=role)
@ -421,7 +421,7 @@ def generate_chatml_messages_with_context(
messages.append(
ChatMessage(
content=construct_structured_message(
user_message, query_images, model_type, vision_enabled, attached_files
user_message, query_images, model_type, vision_enabled, query_files
),
role="user",
)

View file

@ -28,7 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
status_code = 200
@ -70,7 +70,7 @@ async def text_to_image(
query_images=query_images,
user=user,
agent=agent,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)

View file

@ -68,7 +68,7 @@ async def search_online(
query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
@ -86,7 +86,7 @@ async def search_online(
query_images=query_images,
agent=agent,
tracer=tracer,
attached_files=attached_files,
query_files=query_files,
)
subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
@ -178,7 +178,7 @@ async def read_webpages(
query_images: List[str] = None,
agent: Agent = None,
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
@ -190,7 +190,7 @@ async def read_webpages(
user,
query_images,
agent=agent,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)

View file

@ -36,7 +36,7 @@ async def run_code(
query_images: List[str] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
# Generate Code
@ -54,7 +54,7 @@ async def run_code(
query_images,
agent,
tracer,
attached_files,
query_files,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}")
@ -84,7 +84,7 @@ async def generate_python_code(
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
attached_files: str = None,
query_files: str = None,
) -> List[str]:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@ -112,7 +112,7 @@ async def generate_python_code(
response_type="json_object",
user=user,
tracer=tracer,
attached_files=attached_files,
query_files=query_files,
)
# Validate that the response is a non-empty, JSON-serializable list

View file

@ -351,7 +351,7 @@ async def extract_references_and_questions(
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
@ -426,7 +426,7 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -445,7 +445,7 @@ async def extract_references_and_questions(
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -461,7 +461,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -478,7 +478,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)

View file

@ -50,7 +50,7 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
gather_raw_attached_files,
gather_raw_query_files,
generate_excalidraw_diagram,
generate_summary_from_files,
get_conversation_command,
@ -602,7 +602,7 @@ async def chat(
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
raw_images = body.images
raw_attached_files = body.files
raw_query_files = body.files
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter()
@ -614,7 +614,7 @@ async def chat(
q = unquote(q)
train_of_thought = []
nonlocal conversation_id
nonlocal raw_attached_files
nonlocal raw_query_files
tracer: dict = {
"mid": turn_id,
@ -634,10 +634,10 @@ async def chat(
if uploaded_image:
uploaded_images.append(uploaded_image)
attached_files: Dict[str, str] = {}
if raw_attached_files:
for file in raw_attached_files:
attached_files[file.name] = file.content
query_files: Dict[str, str] = {}
if raw_query_files:
for file in raw_query_files:
query_files[file.name] = file.content
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft, train_of_thought
@ -750,7 +750,7 @@ async def chat(
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
attached_file_context = gather_raw_attached_files(attached_files)
attached_file_context = gather_raw_query_files(query_files)
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
@ -760,7 +760,7 @@ async def chat(
user=user,
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
)
@ -806,7 +806,7 @@ async def chat(
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(research_result, InformationCollectionIteration):
@ -855,7 +855,7 @@ async def chat(
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(response, dict) and ChatEvent.STATUS in response:
@ -877,7 +877,7 @@ async def chat(
conversation_id=conversation_id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
raw_query_files=raw_query_files,
tracer=tracer,
)
return
@ -923,7 +923,7 @@ async def chat(
automation_id=automation.id,
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
raw_query_files=raw_query_files,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
@ -946,7 +946,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -992,7 +992,7 @@ async def chat(
custom_filters,
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -1018,7 +1018,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -1059,7 +1059,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -1100,7 +1100,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -1134,7 +1134,7 @@ async def chat(
query_images=uploaded_images,
train_of_thought=train_of_thought,
attached_file_context=attached_file_context,
raw_attached_files=raw_attached_files,
raw_query_files=raw_query_files,
tracer=tracer,
)
content_obj = {
@ -1164,7 +1164,7 @@ async def chat(
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
attached_files=attached_file_context,
query_files=attached_file_context,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -1195,7 +1195,7 @@ async def chat(
query_images=uploaded_images,
train_of_thought=train_of_thought,
attached_file_context=attached_file_context,
raw_attached_files=raw_attached_files,
raw_query_files=raw_query_files,
tracer=tracer,
)
@ -1224,7 +1224,7 @@ async def chat(
uploaded_images,
train_of_thought,
attached_file_context,
raw_attached_files,
raw_query_files,
tracer,
)

View file

@ -256,18 +256,18 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args)
def gather_raw_attached_files(
attached_files: Dict[str, str],
def gather_raw_query_files(
query_files: Dict[str, str],
):
"""
Gather contextual data from the given (raw) files
"""
if len(attached_files) == 0:
if len(query_files) == 0:
return ""
contextual_data = " ".join(
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()]
[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
)
return f"I have attached the following files:\n\n{contextual_data}"
@ -334,7 +334,7 @@ async def aget_relevant_information_sources(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
"""
@ -372,7 +372,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt,
response_type="json_object",
user=user,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -482,7 +482,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
) -> List[str]:
"""
@ -512,7 +512,7 @@ async def infer_webpage_urls(
query_images=query_images,
response_type="json_object",
user=user,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -538,7 +538,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
) -> Set[str]:
"""
@ -568,7 +568,7 @@ async def generate_online_subqueries(
query_images=query_images,
response_type="json_object",
user=user,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -691,7 +691,7 @@ async def generate_summary_from_files(
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
try:
@ -701,17 +701,15 @@ async def generate_summary_from_files(
if len(file_names) > 0:
file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
if (file_objects and len(file_objects) == 0 and not attached_files) or (
not file_objects and not attached_files
):
if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
response_log = "Sorry, I couldn't find anything to summarize."
yield response_log
return
contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
if attached_files:
contextual_data += f"\n\n{attached_files}"
if query_files:
contextual_data += f"\n\n{query_files}"
if not q:
q = "Create a general summary of the file"
@ -754,7 +752,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
if send_status_func:
@ -770,7 +768,7 @@ async def generate_excalidraw_diagram(
query_images=query_images,
user=user,
agent=agent,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -797,7 +795,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
) -> str:
"""
@ -839,7 +837,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
response = response.strip()
@ -887,7 +885,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
attached_files: str = "",
query_files: str = "",
tracer: dict = {},
) -> str:
"""
@ -936,7 +934,7 @@ async def generate_better_image_prompt(
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, attached_files=attached_files, tracer=tracer
image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -952,7 +950,7 @@ async def send_message_to_model_wrapper(
user: KhojUser = None,
query_images: List[str] = None,
context: str = "",
attached_files: str = None,
query_files: str = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
@ -992,7 +990,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return send_message_to_model_offline(
@ -1019,7 +1017,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return send_message_to_model(
@ -1042,7 +1040,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return anthropic_send_message_to_model(
@ -1064,7 +1062,7 @@ async def send_message_to_model_wrapper(
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return gemini_send_message_to_model(
@ -1079,7 +1077,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
attached_files: str = "",
query_files: str = "",
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -1104,7 +1102,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return send_message_to_model_offline(
@ -1126,7 +1124,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
openai_response = send_message_to_model(
@ -1148,7 +1146,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return anthropic_send_message_to_model(
@ -1168,7 +1166,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
attached_files=attached_files,
query_files=query_files,
)
return gemini_send_message_to_model(
@ -1199,8 +1197,8 @@ def generate_chat_response(
meta_research: str = "",
query_images: Optional[List[str]] = None,
train_of_thought: List[Any] = [],
attached_files: str = None,
raw_attached_files: List[FileAttachment] = None,
query_files: str = None,
raw_query_files: List[FileAttachment] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
@ -1223,7 +1221,7 @@ def generate_chat_response(
conversation_id=conversation_id,
query_images=query_images,
train_of_thought=train_of_thought,
raw_attached_files=raw_attached_files,
raw_query_files=raw_query_files,
tracer=tracer,
)
@ -1258,7 +1256,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -1284,7 +1282,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
@ -1307,7 +1305,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1329,7 +1327,7 @@ def generate_chat_response(
agent=agent,
query_images=query_images,
vision_available=vision_available,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)

View file

@ -46,7 +46,7 @@ async def apick_next_tool(
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
attached_files: str = None,
query_files: str = None,
):
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
@ -92,7 +92,7 @@ async def apick_next_tool(
response_type="json_object",
user=user,
query_images=query_images,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
)
except Exception as e:
@ -152,7 +152,7 @@ async def execute_information_collection(
location: LocationData = None,
file_filters: List[str] = [],
tracer: dict = {},
attached_files: str = None,
query_files: str = None,
):
current_iteration = 0
MAX_ITERATIONS = 5
@ -176,7 +176,7 @@ async def execute_information_collection(
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
attached_files=attached_files,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -207,7 +207,7 @@ async def execute_information_collection(
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
attached_files=attached_files,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -269,7 +269,7 @@ async def execute_information_collection(
query_images=query_images,
agent=agent,
tracer=tracer,
attached_files=attached_files,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -300,7 +300,7 @@ async def execute_information_collection(
send_status_func,
query_images=query_images,
agent=agent,
attached_files=attached_files,
query_files=query_files,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -326,7 +326,7 @@ async def execute_information_collection(
query_images=query_images,
agent=agent,
send_status_func=send_status_func,
attached_files=attached_files,
query_files=query_files,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]