Handle multiple images shared in query to chat API

Previously Khoj could only respond to a single shared image at a time.

This change updates the chat API to accept multiple images shared by
the user and pass them to the appropriate chat actors, including the
OpenAI response generation chat actor, to produce an image-aware
response.
Author: Debanjum Singh Solanky
Date:   2024-10-17 23:05:43 -07:00
Parent: d55cba8627
Commit: e2abc1a257

8 changed files with 90 additions and 81 deletions
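For illustration, a request against the updated API might look like the sketch below. This is illustrative only: it assumes the chat endpoint accepts JSON matching the `ChatRequestBody` change further down, the endpoint path and auth header are placeholders, and `images` carries base64 data URLs in the format the handler splits on ",".

import base64

import requests  # any HTTP client works; shown here for illustration

def to_data_url(path: str, mime: str = "image/jpeg") -> str:
    # Encode a local image file as the data URL string the chat handler expects
    with open(path, "rb") as f:
        return f"data:{mime};base64," + base64.b64encode(f.read()).decode()

body = {
    "q": "What do these two photos have in common?",
    # Previously a single `image: str` field; now a list of images
    "images": [to_data_url("photo1.jpg"), to_data_url("photo2.jpg")],
}
# response = requests.post("<khoj-server>/api/chat", json=body, headers={"Authorization": "Bearer <token>"})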


@@ -6,7 +6,7 @@ from typing import Dict, Optional
 from langchain.schema import ChatMessage
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.google.utils import (
     format_messages_for_gemini,
@@ -187,6 +187,7 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
     )
     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)


@@ -30,7 +30,7 @@ def extract_questions(
     api_base_url=None,
     location_data: LocationData = None,
     user: KhojUser = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
@@ -74,7 +74,7 @@ def extract_questions(
     prompt = construct_structured_message(
         message=prompt,
-        image_url=uploaded_image_url,
+        images=query_images,
         model_type=ChatModelOptions.ModelType.OPENAI,
         vision_enabled=vision_enabled,
     )
@@ -135,7 +135,7 @@ def converse(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
-    image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_available: bool = False,
 ):
     """
@@ -191,7 +191,7 @@ def converse(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
-        uploaded_image_url=image_url,
+        query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
     )


@@ -109,7 +109,7 @@ def save_to_conversation_log(
     client_application: ClientApplication = None,
     conversation_id: str = None,
     automation_id: str = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     updated_conversation = message_to_log(
@@ -117,7 +117,7 @@ def save_to_conversation_log(
         chat_response=chat_response,
         user_message_metadata={
             "created": user_message_time,
-            "uploadedImageData": uploaded_image_url,
+            "images": query_images,
         },
         khoj_message_metadata={
             "context": compiled_references,
@@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
 )


-# Format user and system messages to chatml format
-def construct_structured_message(message, image_url, model_type, vision_enabled):
-    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
-        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
+def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+    """
+    Format messages into appropriate multimedia format for supported chat model types
+    """
+    if not images or not vision_enabled:
+        return message
+
+    if model_type == ChatModelOptions.ModelType.OPENAI:
+        return [
+            {"type": "text", "text": message},
+            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
+        ]
+
     return message
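A minimal sketch of what the rewritten construct_structured_message now produces for an OpenAI-type vision model, given two shared images (the URL values are illustrative):

message = construct_structured_message(
    message="What is in these images?",
    images=["https://bucket/img1.webp", "https://bucket/img2.webp"],
    model_type=ChatModelOptions.ModelType.OPENAI,
    vision_enabled=True,
)
# message == [
#     {"type": "text", "text": "What is in these images?"},
#     {"type": "image_url", "image_url": {"url": "https://bucket/img1.webp"}},
#     {"type": "image_url", "image_url": {"url": "https://bucket/img2.webp"}},
# ]

With vision disabled, no images, or a non-OpenAI model type, the plain text message is returned unchanged.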
@@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
     loaded_model: Optional[Llama] = None,
     max_prompt_size=None,
     tokenizer_name=None,
-    uploaded_image_url=None,
+    query_images=None,
     vision_enabled=False,
     model_type="",
 ):
@@ -183,9 +191,7 @@ def generate_chatml_messages_with_context(
         message_content = chat["message"] + message_notes
-        message_content = construct_structured_message(
-            message_content, chat.get("uploadedImageData"), model_type, vision_enabled
-        )
+        message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)

         reconstructed_message = ChatMessage(content=message_content, role=role)
@@ -198,7 +204,7 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
-                content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
+                content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
                 role="user",
             )
         )
@@ -222,7 +228,6 @@ def truncate_messages(
     tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
-
     default_tokenizer = "gpt-4o"

     try:
@@ -252,6 +257,7 @@ def truncate_messages(
             system_message = messages.pop(idx)
             break

+    # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
     system_message_tokens = (
         len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
     )


@@ -26,7 +26,7 @@ async def text_to_image(
     references: List[Dict[str, Any]],
     online_results: Dict[str, Any],
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     agent: Agent = None,
 ):
     status_code = 200
@@ -65,7 +65,7 @@ async def text_to_image(
         note_references=references,
         online_results=online_results,
         model_type=text_to_image_config.model_type,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
         user=user,
         agent=agent,
     )


@@ -62,7 +62,7 @@ async def search_online(
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
     custom_filters: List[str] = [],
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     query += " ".join(custom_filters)
@@ -73,7 +73,7 @@ async def search_online(
     # Breakdown the query into subqueries to get the correct answer
     subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
+        query, conversation_history, location, user, query_images=query_images, agent=agent
     )
     response_dict = {}
@@ -151,7 +151,7 @@ async def read_webpages(
     location: LocationData,
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     "Infer web pages to read from the query and extract relevant information from them"
@@ -159,7 +159,7 @@ async def read_webpages(
     if send_status_func:
         async for event in send_status_func(f"**Inferring web pages to read**"):
             yield {ChatEvent.STATUS: event}
-    urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
+    urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:


@@ -340,7 +340,7 @@ async def extract_references_and_questions(
     conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
     location_data: LocationData = None,
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     agent: Agent = None,
 ):
     user = request.user.object if request.user.is_authenticated else None
@@ -431,7 +431,7 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=query_images,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )


@@ -535,7 +535,7 @@ class ChatRequestBody(BaseModel):
     country: Optional[str] = None
     country_code: Optional[str] = None
     timezone: Optional[str] = None
-    image: Optional[str] = None
+    images: Optional[list[str]] = None
     create_new: Optional[bool] = False
@@ -564,9 +564,9 @@ async def chat(
     country = body.country or get_country_name_from_timezone(body.timezone)
     country_code = body.country_code or get_country_code_from_timezone(body.timezone)
     timezone = body.timezone
-    image = body.image
+    raw_images = body.images

-    async def event_generator(q: str, image: str):
+    async def event_generator(q: str, images: list[str]):
         start_time = time.perf_counter()
         ttft = None
         chat_metadata: dict = {}
@@ -576,16 +576,16 @@ async def chat(
         q = unquote(q)
         nonlocal conversation_id

-        uploaded_image_url = None
-        if image:
-            decoded_string = unquote(image)
-            base64_data = decoded_string.split(",", 1)[1]
-            image_bytes = base64.b64decode(base64_data)
-            webp_image_bytes = convert_image_to_webp(image_bytes)
-            try:
-                uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
-            except:
-                uploaded_image_url = None
+        uploaded_images: list[str] = []
+        if images:
+            for image in images:
+                decoded_string = unquote(image)
+                base64_data = decoded_string.split(",", 1)[1]
+                image_bytes = base64.b64decode(base64_data)
+                webp_image_bytes = convert_image_to_webp(image_bytes)
+                uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
+                if uploaded_image:
+                    uploaded_images.append(uploaded_image)

         async def send_event(event_type: ChatEvent, data: str | dict):
             nonlocal connection_alive, ttft
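Read as a pipeline, each shared image is percent-decoded, stripped of its "data:image/...;base64," prefix, base64-decoded, converted to webp, and uploaded; only successful uploads are kept. A standalone sketch of that loop, reusing the convert_image_to_webp and upload_image_to_bucket helpers from the diff (the function name upload_query_images is hypothetical):

import base64
from urllib.parse import unquote

def upload_query_images(images: list[str], user_id) -> list[str]:
    uploaded_images: list[str] = []
    for image in images or []:
        decoded_string = unquote(image)
        # Drop the "data:image/...;base64," prefix before decoding
        base64_data = decoded_string.split(",", 1)[1]
        image_bytes = base64.b64decode(base64_data)
        webp_image_bytes = convert_image_to_webp(image_bytes)  # existing Khoj utility
        uploaded_image = upload_image_to_bucket(webp_image_bytes, user_id)  # existing Khoj utility
        if uploaded_image:  # skip images that failed to upload
            uploaded_images.append(uploaded_image)
    return uploaded_images

Note the new code also drops the old bare try/except around the upload, relying instead on upload_image_to_bucket returning a falsy value on failure.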
@@ -692,7 +692,7 @@ async def chat(
             meta_log,
             is_automated_task,
             user=user,
-            uploaded_image_url=uploaded_image_url,
+            query_images=uploaded_images,
             agent=agent,
         )
         conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -701,7 +701,7 @@ async def chat(
         ):
             yield result

-        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
+        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
         async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
             yield result
         if mode not in conversation_commands:
@@ -764,7 +764,7 @@ async def chat(
                 q,
                 contextual_data,
                 conversation_history=meta_log,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 user=user,
                 agent=agent,
             )
@@ -785,7 +785,7 @@ async def chat(
                 intent_type="summarize",
                 client_application=request.user.client_app,
                 conversation_id=conversation_id,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             return
@@ -828,7 +828,7 @@ async def chat(
                 conversation_id=conversation_id,
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             async for result in send_llm_response(llm_response):
                 yield result
@@ -848,7 +848,7 @@ async def chat(
             conversation_commands,
             location,
             partial(send_event, ChatEvent.STATUS),
-            uploaded_image_url=uploaded_image_url,
+            query_images=uploaded_images,
            agent=agent,
         ):
             if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -892,7 +892,7 @@ async def chat(
                 user,
                 partial(send_event, ChatEvent.STATUS),
                 custom_filters,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -916,7 +916,7 @@ async def chat(
                 location,
                 user,
                 partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -966,20 +966,20 @@ async def chat(
                 references=compiled_references,
                 online_results=online_results,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
                 else:
-                    image, status_code, improved_image_prompt, intent_type = result
+                    generated_image, status_code, improved_image_prompt, intent_type = result

-            if image is None or status_code != 200:
+            if generated_image is None or status_code != 200:
                 content_obj = {
                     "content-type": "application/json",
                     "intentType": intent_type,
                     "detail": improved_image_prompt,
-                    "image": image,
+                    "image": None,
                 }
                 async for result in send_llm_response(json.dumps(content_obj)):
                     yield result
@@ -987,7 +987,7 @@ async def chat(
             await sync_to_async(save_to_conversation_log)(
                 q,
-                image,
+                generated_image,
                 user,
                 meta_log,
                 user_message_time,
@@ -997,12 +997,12 @@ async def chat(
                 conversation_id=conversation_id,
                 compiled_references=compiled_references,
                 online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             content_obj = {
                 "intentType": intent_type,
                 "inferredQueries": [improved_image_prompt],
-                "image": image,
+                "image": generated_image,
             }
             async for result in send_llm_response(json.dumps(content_obj)):
                 yield result
@@ -1024,7 +1024,7 @@ async def chat(
             conversation_id,
             location,
             user_name,
-            uploaded_image_url,
+            uploaded_images,
         )

         # Send Response
@@ -1050,9 +1050,9 @@ async def chat(
     ## Stream Text Response
     if stream:
-        return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
+        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")

     ## Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, image=image)
+        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)


@@ -290,7 +290,7 @@ async def aget_relevant_information_sources(
     conversation_history: dict,
     is_task: bool,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """
@@ -309,8 +309,8 @@ async def aget_relevant_information_sources(
     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -367,7 +367,7 @@ async def aget_relevant_output_modes(
     conversation_history: dict,
     is_task: bool = False,
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """
@@ -389,8 +389,8 @@ async def aget_relevant_output_modes(
     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -433,7 +433,7 @@ async def infer_webpage_urls(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """
@@ -459,7 +459,7 @@ async def infer_webpage_urls(
     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -479,7 +479,7 @@ async def generate_online_subqueries(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """
@@ -505,7 +505,7 @@ async def generate_online_subqueries(
     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list
@@ -524,7 +524,7 @@ async def generate_online_subqueries(
 async def schedule_query(
-    q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
+    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
 ) -> Tuple[str, ...]:
     """
     Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -537,7 +537,7 @@ async def schedule_query(
     )

     raw_response = await send_message_to_model_wrapper(
-        crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+        crontime_prompt, query_images=query_images, response_type="json_object", user=user
     )

     # Validate that the response is a non-empty, JSON-serializable list
@@ -583,7 +583,7 @@ async def extract_relevant_summary(
     q: str,
     corpus: str,
     conversation_history: dict,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> Union[str, None]:
@@ -612,7 +612,7 @@ async def extract_relevant_summary(
             extract_relevant_information,
             prompts.system_prompt_extract_relevant_summary,
             user=user,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
         )
     return response.strip()
@@ -624,7 +624,7 @@ async def generate_better_image_prompt(
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
     model_type: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> str:
@@ -676,7 +676,7 @@ async def generate_better_image_prompt(
     )

     with timer("Chat actor: Generate contextual image prompt", logger):
-        response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
+        response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
             response = response[1:-1]
@@ -689,11 +689,11 @@ async def send_message_to_model_wrapper(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
-    if not vision_available and uploaded_image_url:
+    if not vision_available and query_images:
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
@@ -746,7 +746,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )
@@ -766,7 +766,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )
@@ -784,7 +784,8 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(
@@ -875,6 +876,7 @@ def send_message_to_model_wrapper_sync(
             model_name=chat_model,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(
@@ -900,7 +902,7 @@ def generate_chat_response(
     conversation_id: str = None,
     location_data: LocationData = None,
     user_name: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
@@ -919,12 +921,12 @@ def generate_chat_response(
         inferred_queries=inferred_queries,
         client_application=client_application,
         conversation_id=conversation_id,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
     )

     conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
     vision_available = conversation_config.vision_enabled
-    if not vision_available and uploaded_image_url:
+    if not vision_available and query_images:
         vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
@@ -955,7 +957,7 @@ def generate_chat_response(
         chat_response = converse(
             compiled_references,
             q,
-            image_url=uploaded_image_url,
+            query_images=query_images,
             online_results=online_results,
             conversation_log=meta_log,
             model=chat_model,