mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Handle multiple images shared in query to chat API
Previously Khoj could respond to a single shared image at a time. This changes updates the chat API to accept multiple images shared by the user and send it to the appropriate chat actors including the openai response generation chat actor for getting an image aware response
This commit is contained in:
parent
d55cba8627
commit
e2abc1a257
8 changed files with 90 additions and 81 deletions
|
@ -6,7 +6,7 @@ from typing import Dict, Optional
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
|
@ -187,6 +187,7 @@ def converse_gemini(
|
|||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
)
|
||||
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
|
|
@ -30,7 +30,7 @@ def extract_questions(
|
|||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
):
|
||||
|
@ -74,7 +74,7 @@ def extract_questions(
|
|||
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
image_url=uploaded_image_url,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
@ -135,7 +135,7 @@ def converse(
|
|||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
):
|
||||
"""
|
||||
|
@ -191,7 +191,7 @@ def converse(
|
|||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
uploaded_image_url=image_url,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
)
|
||||
|
|
|
@ -109,7 +109,7 @@ def save_to_conversation_log(
|
|||
client_application: ClientApplication = None,
|
||||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -117,7 +117,7 @@ def save_to_conversation_log(
|
|||
chat_response=chat_response,
|
||||
user_message_metadata={
|
||||
"created": user_message_time,
|
||||
"uploadedImageData": uploaded_image_url,
|
||||
"images": query_images,
|
||||
},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
|
@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
|||
)
|
||||
|
||||
|
||||
# Format user and system messages to chatml format
|
||||
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
||||
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
||||
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
|
||||
"""
|
||||
Format messages into appropriate multimedia format for supported chat model types
|
||||
"""
|
||||
if not images or not vision_enabled:
|
||||
return message
|
||||
|
||||
if model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
return [
|
||||
{"type": "text", "text": message},
|
||||
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
||||
]
|
||||
return message
|
||||
|
||||
|
||||
|
@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
|
|||
loaded_model: Optional[Llama] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
uploaded_image_url=None,
|
||||
query_images=None,
|
||||
vision_enabled=False,
|
||||
model_type="",
|
||||
):
|
||||
|
@ -183,9 +191,7 @@ def generate_chatml_messages_with_context(
|
|||
|
||||
message_content = chat["message"] + message_notes
|
||||
|
||||
message_content = construct_structured_message(
|
||||
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
||||
)
|
||||
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
|
||||
|
||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||
|
||||
|
@ -198,7 +204,7 @@ def generate_chatml_messages_with_context(
|
|||
if not is_none_or_empty(user_message):
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
||||
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
|
||||
role="user",
|
||||
)
|
||||
)
|
||||
|
@ -222,7 +228,6 @@ def truncate_messages(
|
|||
tokenizer_name=None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
default_tokenizer = "gpt-4o"
|
||||
|
||||
try:
|
||||
|
@ -252,6 +257,7 @@ def truncate_messages(
|
|||
system_message = messages.pop(idx)
|
||||
break
|
||||
|
||||
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
||||
system_message_tokens = (
|
||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ async def text_to_image(
|
|||
references: List[Dict[str, Any]],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
status_code = 200
|
||||
|
@ -65,7 +65,7 @@ async def text_to_image(
|
|||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
|
|
|
@ -62,7 +62,7 @@ async def search_online(
|
|||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
custom_filters: List[str] = [],
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
|
@ -73,7 +73,7 @@ async def search_online(
|
|||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||
)
|
||||
response_dict = {}
|
||||
|
||||
|
@ -151,7 +151,7 @@ async def read_webpages(
|
|||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
|
@ -159,7 +159,7 @@ async def read_webpages(
|
|||
if send_status_func:
|
||||
async for event in send_status_func(f"**Inferring web pages to read**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
|
|
|
@ -340,7 +340,7 @@ async def extract_references_and_questions(
|
|||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
location_data: LocationData = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
@ -431,7 +431,7 @@ async def extract_references_and_questions(
|
|||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
|
|
@ -535,7 +535,7 @@ class ChatRequestBody(BaseModel):
|
|||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
images: Optional[list[str]] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
|
@ -564,9 +564,9 @@ async def chat(
|
|||
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||
timezone = body.timezone
|
||||
image = body.image
|
||||
raw_images = body.images
|
||||
|
||||
async def event_generator(q: str, image: str):
|
||||
async def event_generator(q: str, images: list[str]):
|
||||
start_time = time.perf_counter()
|
||||
ttft = None
|
||||
chat_metadata: dict = {}
|
||||
|
@ -576,16 +576,16 @@ async def chat(
|
|||
q = unquote(q)
|
||||
nonlocal conversation_id
|
||||
|
||||
uploaded_image_url = None
|
||||
if image:
|
||||
decoded_string = unquote(image)
|
||||
base64_data = decoded_string.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
try:
|
||||
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
except:
|
||||
uploaded_image_url = None
|
||||
uploaded_images: list[str] = []
|
||||
if images:
|
||||
for image in images:
|
||||
decoded_string = unquote(image)
|
||||
base64_data = decoded_string.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
if uploaded_image:
|
||||
uploaded_images.append(uploaded_image)
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal connection_alive, ttft
|
||||
|
@ -692,7 +692,7 @@ async def chat(
|
|||
meta_log,
|
||||
is_automated_task,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
|
@ -701,7 +701,7 @@ async def chat(
|
|||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
|
@ -764,7 +764,7 @@ async def chat(
|
|||
q,
|
||||
contextual_data,
|
||||
conversation_history=meta_log,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
|
@ -785,7 +785,7 @@ async def chat(
|
|||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -828,7 +828,7 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -848,7 +848,7 @@ async def chat(
|
|||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -892,7 +892,7 @@ async def chat(
|
|||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
custom_filters,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -916,7 +916,7 @@ async def chat(
|
|||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -966,20 +966,20 @@ async def chat(
|
|||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
image, status_code, improved_image_prompt, intent_type = result
|
||||
generated_image, status_code, improved_image_prompt, intent_type = result
|
||||
|
||||
if image is None or status_code != 200:
|
||||
if generated_image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"content-type": "application/json",
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"image": image,
|
||||
"image": None,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
|
@ -987,7 +987,7 @@ async def chat(
|
|||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
generated_image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
|
@ -997,12 +997,12 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
"inferredQueries": [improved_image_prompt],
|
||||
"image": image,
|
||||
"image": generated_image,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
|
@ -1024,7 +1024,7 @@ async def chat(
|
|||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
uploaded_image_url,
|
||||
uploaded_images,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
@ -1050,9 +1050,9 @@ async def chat(
|
|||
|
||||
## Stream Text Response
|
||||
if stream:
|
||||
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
|
||||
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
||||
## Non-Streaming Text Response
|
||||
else:
|
||||
response_iterator = event_generator(q, image=image)
|
||||
response_iterator = event_generator(q, images=raw_images)
|
||||
response_data = await read_chat_stream(response_iterator)
|
||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||
|
|
|
@ -290,7 +290,7 @@ async def aget_relevant_information_sources(
|
|||
conversation_history: dict,
|
||||
is_task: bool,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
|
@ -309,8 +309,8 @@ async def aget_relevant_information_sources(
|
|||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
if uploaded_image_url:
|
||||
query = f"[placeholder for user attached image]\n{query}"
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -367,7 +367,7 @@ async def aget_relevant_output_modes(
|
|||
conversation_history: dict,
|
||||
is_task: bool = False,
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
|
@ -389,8 +389,8 @@ async def aget_relevant_output_modes(
|
|||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
if uploaded_image_url:
|
||||
query = f"[placeholder for user attached image]\n{query}"
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -433,7 +433,7 @@ async def infer_webpage_urls(
|
|||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
@ -459,7 +459,7 @@ async def infer_webpage_urls(
|
|||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
|
@ -479,7 +479,7 @@ async def generate_online_subqueries(
|
|||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
@ -505,7 +505,7 @@ async def generate_online_subqueries(
|
|||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -524,7 +524,7 @@ async def generate_online_subqueries(
|
|||
|
||||
|
||||
async def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||
) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -537,7 +537,7 @@ async def schedule_query(
|
|||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -583,7 +583,7 @@ async def extract_relevant_summary(
|
|||
q: str,
|
||||
corpus: str,
|
||||
conversation_history: dict,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
) -> Union[str, None]:
|
||||
|
@ -612,7 +612,7 @@ async def extract_relevant_summary(
|
|||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_summary,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -624,7 +624,7 @@ async def generate_better_image_prompt(
|
|||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
model_type: Optional[str] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
) -> str:
|
||||
|
@ -676,7 +676,7 @@ async def generate_better_image_prompt(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
|
||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
@ -689,11 +689,11 @@ async def send_message_to_model_wrapper(
|
|||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
|
@ -746,7 +746,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
|
@ -766,7 +766,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
|
@ -784,7 +784,8 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
|
@ -875,6 +876,7 @@ def send_message_to_model_wrapper_sync(
|
|||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
|
@ -900,7 +902,7 @@ def generate_chat_response(
|
|||
conversation_id: str = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -919,12 +921,12 @@ def generate_chat_response(
|
|||
inferred_queries=inferred_queries,
|
||||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
|
@ -955,7 +957,7 @@ def generate_chat_response(
|
|||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
|
|
Loading…
Reference in a new issue