Handle multiple images shared in query to chat API

Previously, Khoj could only respond to a single shared image at a time.

This change updates the chat API to accept multiple images shared by the
user and sends them to the appropriate chat actors, including the OpenAI
response generation chat actor, to get an image-aware response.
Debanjum Singh Solanky 2024-10-17 23:05:43 -07:00
parent d55cba8627
commit e2abc1a257
8 changed files with 90 additions and 81 deletions
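
For illustration, a client request against the updated API might look like the sketch below. The images field and its base64 data URL format follow the ChatRequestBody and decoding logic in this diff; the endpoint path, port, and HTTP client are assumptions for the example.

import base64

import requests  # assumed HTTP client for this sketch

def to_data_url(path: str) -> str:
    # Encode a local image as a base64 data URL; the chat API splits the
    # string on "," and base64-decodes the payload, so this format round-trips
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"

response = requests.post(
    "http://localhost:42110/api/chat",  # hypothetical server address and route
    json={
        "q": "What do these two charts have in common?",
        "images": [to_data_url("chart1.jpg"), to_data_url("chart2.jpg")],
    },
)
print(response.json())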

View file

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
@@ -187,6 +187,7 @@ def converse_gemini(
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.GOOGLE,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@@ -30,7 +30,7 @@ def extract_questions(
api_base_url=None,
location_data: LocationData = None,
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
@@ -74,7 +74,7 @@ def extract_questions(
prompt = construct_structured_message(
message=prompt,
image_url=uploaded_image_url,
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
)
@@ -135,7 +135,7 @@ def converse(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
):
"""
@@ -191,7 +191,7 @@ def converse(
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
uploaded_image_url=image_url,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
)

View file

@@ -109,7 +109,7 @@ def save_to_conversation_log(
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -117,7 +117,7 @@ def save_to_conversation_log(
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"uploadedImageData": uploaded_image_url,
"images": query_images,
},
khoj_message_metadata={
"context": compiled_references,
@@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
)
# Format user and system messages to chatml format
def construct_structured_message(message, image_url, model_type, vision_enabled):
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
"""
Format messages into appropriate multimedia format for supported chat model types
"""
if not images or not vision_enabled:
return message
if model_type == ChatModelOptions.ModelType.OPENAI:
return [
{"type": "text", "text": message},
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
]
return message
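
The rewritten construct_structured_message is shown in full above: for an OpenAI-type vision model it now returns one text part followed by one image_url part per shared image, and every other case falls back to the plain string. A standalone sketch of the same logic and its output, with illustrative URLs:

def construct_structured_message(message, images, vision_enabled=True):
    # Mirrors the new helper for ModelType.OPENAI; other model types
    # (and vision-disabled configs) just get the original string back
    if not images or not vision_enabled:
        return message
    return [
        {"type": "text", "text": message},
        *[{"type": "image_url", "image_url": {"url": image}} for image in images],
    ]

content = construct_structured_message(
    "Compare these.", ["https://example.com/a.webp", "https://example.com/b.webp"]
)
# [{'type': 'text', 'text': 'Compare these.'},
#  {'type': 'image_url', 'image_url': {'url': 'https://example.com/a.webp'}},
#  {'type': 'image_url', 'image_url': {'url': 'https://example.com/b.webp'}}]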
@@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
uploaded_image_url=None,
query_images=None,
vision_enabled=False,
model_type="",
):
@@ -183,9 +191,7 @@ def generate_chatml_messages_with_context(
message_content = chat["message"] + message_notes
message_content = construct_structured_message(
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
)
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
reconstructed_message = ChatMessage(content=message_content, role=role)
@@ -198,7 +204,7 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
role="user",
)
)
@@ -222,7 +228,6 @@ def truncate_messages(
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "gpt-4o"
try:
@@ -252,6 +257,7 @@ def truncate_messages(
system_message = messages.pop(idx)
break
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
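
One way to tackle the truncation TODO above would be to count only the text parts when message.content is a list of parts rather than a string. A rough sketch, not part of this commit, and image-token accounting would still be model specific:

import tiktoken
from langchain.schema import ChatMessage

encoder = tiktoken.encoding_for_model("gpt-4o")  # the default tokenizer named above

def count_message_tokens(message: ChatMessage) -> int:
    # Plain string content, as before
    if isinstance(message.content, str):
        return len(encoder.encode(message.content))
    # Multi-part content: count the text parts only; image parts would need
    # model-specific accounting (e.g. per-tile costs for OpenAI vision models)
    return sum(
        len(encoder.encode(part["text"]))
        for part in message.content
        if isinstance(part, dict) and part.get("type") == "text"
    )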

View file

@@ -26,7 +26,7 @@ async def text_to_image(
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
):
status_code = 200
@@ -65,7 +65,7 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
user=user,
agent=agent,
)

View file

@@ -62,7 +62,7 @@ async def search_online(
user: KhojUser,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
query += " ".join(custom_filters)
@@ -73,7 +73,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
query, conversation_history, location, user, query_images=query_images, agent=agent
)
response_dict = {}
@@ -151,7 +151,7 @@ async def read_webpages(
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"Infer web pages to read from the query and extract relevant information from them"
@@ -159,7 +159,7 @@ async def read_webpages(
if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:

View file

@@ -340,7 +340,7 @@ async def extract_references_and_questions(
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
):
user = request.user.object if request.user.is_authenticated else None
@@ -431,7 +431,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
)

View file

@@ -535,7 +535,7 @@ class ChatRequestBody(BaseModel):
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
image: Optional[str] = None
images: Optional[list[str]] = None
create_new: Optional[bool] = False
@@ -564,9 +564,9 @@ async def chat(
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image
raw_images = body.images
async def event_generator(q: str, image: str):
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
@@ -576,16 +576,16 @@
q = unquote(q)
nonlocal conversation_id
uploaded_image_url = None
if image:
uploaded_images: list[str] = []
if images:
for image in images:
decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes)
try:
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
except:
uploaded_image_url = None
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
if uploaded_image:
uploaded_images.append(uploaded_image)
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
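
Pulled out of the diff for readability, the new upload loop amounts to the sketch below. convert_image_to_webp and upload_image_to_bucket are the helpers the diff calls; the per-image try/except is an added safeguard, not in the commit, since the new code drops the old try/except and a single failed upload would otherwise raise out of the loop.

import base64
from urllib.parse import unquote

def upload_query_images(images: list[str], user_id) -> list[str]:
    uploaded_images: list[str] = []
    for image in images:
        decoded_string = unquote(image)
        base64_data = decoded_string.split(",", 1)[1]  # strip the data URL prefix
        image_bytes = base64.b64decode(base64_data)
        webp_image_bytes = convert_image_to_webp(image_bytes)  # helper from the diff
        try:
            # helper from the diff; the try/except wrapper is an assumption
            uploaded_image = upload_image_to_bucket(webp_image_bytes, user_id)
        except Exception:
            uploaded_image = None
        if uploaded_image:
            uploaded_images.append(uploaded_image)
    return uploaded_images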
@@ -692,7 +692,7 @@ async def chat(
meta_log,
is_automated_task,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -701,7 +701,7 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@@ -764,7 +764,7 @@ async def chat(
q,
contextual_data,
conversation_history=meta_log,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
user=user,
agent=agent,
)
@@ -785,7 +785,7 @@ async def chat(
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
return
@@ -828,7 +828,7 @@ async def chat(
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
async for result in send_llm_response(llm_response):
yield result
@@ -848,7 +848,7 @@ async def chat(
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -892,7 +892,7 @@ async def chat(
user,
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -916,7 +916,7 @@ async def chat(
location,
user,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -966,20 +966,20 @@ async def chat(
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
image, status_code, improved_image_prompt, intent_type = result
generated_image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
if generated_image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
@@ -987,7 +987,7 @@ async def chat(
await sync_to_async(save_to_conversation_log)(
q,
image,
generated_image,
user,
meta_log,
user_message_time,
@@ -997,12 +997,12 @@ async def chat(
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
content_obj = {
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": image,
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
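
With the rename, the streamed payload now cleanly separates the generated image from the user's shared images, which are only persisted in the conversation log. On success the payload has roughly this shape (values are made up for illustration):

{
    "intentType": "text-to-image",
    "inferredQueries": ["A watercolor painting of a mountain lake at dawn"],
    "image": "<the generated image returned by text_to_image>",
}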
@@ -1024,7 +1024,7 @@ async def chat(
conversation_id,
location,
user_name,
uploaded_image_url,
uploaded_images,
)
# Send Response
@@ -1050,9 +1050,9 @@ async def chat(
## Stream Text Response
if stream:
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
## Non-Streaming Text Response
else:
response_iterator = event_generator(q, image=image)
response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View file

@@ -290,7 +290,7 @@ async def aget_relevant_information_sources(
conversation_history: dict,
is_task: bool,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"""
@@ -309,8 +309,8 @@ async def aget_relevant_information_sources(
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
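
These selector prompts are text-only, so attached images are represented by a counted placeholder rather than sent to the model. For example:

query_images = ["https://example.com/a.webp", "https://example.com/b.webp"]
query = "What breed are these dogs?"
if query_images:
    query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
# query == "[placeholder for 2 user attached images]\nWhat breed are these dogs?"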
@@ -367,7 +367,7 @@ async def aget_relevant_output_modes(
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"""
@@ -389,8 +389,8 @@ async def aget_relevant_output_modes(
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -433,7 +433,7 @@ async def infer_webpage_urls(
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
) -> List[str]:
"""
@@ -459,7 +459,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -479,7 +479,7 @@ async def generate_online_subqueries(
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
) -> List[str]:
"""
@@ -505,7 +505,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -524,7 +524,7 @@ async def generate_online_subqueries(
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -537,7 +537,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
crontime_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
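
The validation these comments describe is not shown in the hunks; a minimal sketch of what such a check might look like, with a hypothetical helper name:

import json

def parse_json_list(raw_response: str) -> list:
    # Parse a model response expected to be a non-empty, JSON-serializable list
    response = json.loads(raw_response.strip())
    if not isinstance(response, list) or not response:
        raise ValueError(f"Expected a non-empty JSON list, got: {raw_response}")
    return response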
@@ -583,7 +583,7 @@ async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
uploaded_image_url: str = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
) -> Union[str, None]:
@@ -612,7 +612,7 @@ async def extract_relevant_summary(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
)
return response.strip()
@@ -624,7 +624,7 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
) -> str:
@@ -676,7 +676,7 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@@ -689,11 +689,11 @@ async def send_message_to_model_wrapper(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
if not vision_available and query_images:
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
@@ -746,7 +746,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
@@ -766,7 +766,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
@@ -784,7 +784,8 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
return gemini_send_message_to_model(
@@ -875,6 +876,7 @@ def send_message_to_model_wrapper_sync(
model_name=chat_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
)
return gemini_send_message_to_model(
@@ -900,7 +902,7 @@ def generate_chat_response(
conversation_id: str = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@@ -919,12 +921,12 @@ def generate_chat_response(
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
@@ -955,7 +957,7 @@ def generate_chat_response(
chat_response = converse(
compiled_references,
q,
image_url=uploaded_image_url,
query_images=query_images,
online_results=online_results,
conversation_log=meta_log,
model=chat_model,