diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 7359b3eb..e8848806 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -6,7 +6,7 @@
 from typing import Dict, Optional

 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.google.utils import (
     format_messages_for_gemini,
@@ -187,6 +187,7 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
     )

     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index ad02b10e..4a656fac 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -30,7 +30,7 @@ def extract_questions(
     api_base_url=None,
     location_data: LocationData = None,
     user: KhojUser = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
@@ -74,7 +74,7 @@ def extract_questions(

     prompt = construct_structured_message(
         message=prompt,
-        image_url=uploaded_image_url,
+        images=query_images,
         model_type=ChatModelOptions.ModelType.OPENAI,
         vision_enabled=vision_enabled,
     )
@@ -135,7 +135,7 @@ def converse(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
-    image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_available: bool = False,
 ):
     """
@@ -191,7 +191,7 @@ def converse(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
-        uploaded_image_url=image_url,
+        query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
     )
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e841c484..8d799745 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -109,7 +109,7 @@ def save_to_conversation_log(
     client_application: ClientApplication = None,
     conversation_id: str = None,
     automation_id: str = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     updated_conversation = message_to_log(
@@ -117,7 +117,7 @@ def save_to_conversation_log(
         chat_response=chat_response,
         user_message_metadata={
             "created": user_message_time,
-            "uploadedImageData": uploaded_image_url,
+            "images": query_images,
         },
         khoj_message_metadata={
             "context": compiled_references,
@@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
 )


-# Format user and system messages to chatml format
-def construct_structured_message(message, image_url, model_type, vision_enabled):
-    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
-        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
+def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+    """
+    Format messages into appropriate multimedia format for supported chat model types
+    """
+    if not images or not vision_enabled:
+        return message
+
+    if model_type == ChatModelOptions.ModelType.OPENAI:
+        return [
+            {"type": "text", "text": message},
+            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
+        ]

     return message
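Note on the construct_structured_message rewrite above: for vision-enabled OpenAI models it now expands every image URL into its own image_url content part, instead of supporting at most one image. A minimal sketch of the resulting payload, with hypothetical query text and bucket URLs:

    from khoj.database.models import ChatModelOptions
    from khoj.processor.conversation.utils import construct_structured_message

    content = construct_structured_message(
        message="What breed are the dogs in these photos?",
        images=["https://bucket.example/a.webp", "https://bucket.example/b.webp"],
        model_type=ChatModelOptions.ModelType.OPENAI,
        vision_enabled=True,
    )
    # content == [
    #     {"type": "text", "text": "What breed are the dogs in these photos?"},
    #     {"type": "image_url", "image_url": {"url": "https://bucket.example/a.webp"}},
    #     {"type": "image_url", "image_url": {"url": "https://bucket.example/b.webp"}},
    # ]
    # With no images, vision disabled, or a non-OpenAI model type, the plain
    # message string is returned unchanged.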
@@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
     loaded_model: Optional[Llama] = None,
     max_prompt_size=None,
     tokenizer_name=None,
-    uploaded_image_url=None,
+    query_images=None,
     vision_enabled=False,
     model_type="",
 ):
@@ -183,9 +191,7 @@ def generate_chatml_messages_with_context(

             message_content = chat["message"] + message_notes

-            message_content = construct_structured_message(
-                message_content, chat.get("uploadedImageData"), model_type, vision_enabled
-            )
+            message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)

             reconstructed_message = ChatMessage(content=message_content, role=role)

@@ -198,7 +204,7 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
-                content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
+                content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
                 role="user",
             )
         )
@@ -222,7 +228,6 @@ def truncate_messages(
     tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
-
     default_tokenizer = "gpt-4o"

     try:
@@ -252,6 +257,7 @@ def truncate_messages(
             system_message = messages.pop(idx)
             break

+    # TODO: Handle truncation of multi-part message.content, i.e. when message.content is a list[dict] rather than a string
     system_message_tokens = (
         len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
     )
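The TODO added above exists because message.content can now be a list of text and image_url parts, which the string-only token counting below it does not handle. One possible accounting, sketched under the assumption that text parts are tokenized normally and each image is charged a flat cost (the helper name and the 85-token figure are illustrative assumptions, not part of this patch):

    def count_multipart_tokens(content, encoder) -> int:
        # Hypothetical helper: token cost of str or list[dict] message content.
        if isinstance(content, str):
            return len(encoder.encode(content))
        tokens = 0
        for part in content:
            if part.get("type") == "text":
                tokens += len(encoder.encode(part["text"]))
            elif part.get("type") == "image_url":
                tokens += 85  # assumed flat per-image cost
        return tokens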
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 59073731..ee39bdc5 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -26,7 +26,7 @@ async def text_to_image(
     references: List[Dict[str, Any]],
     online_results: Dict[str, Any],
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     agent: Agent = None,
 ):
     status_code = 200
@@ -65,7 +65,7 @@ async def text_to_image(
         note_references=references,
         online_results=online_results,
         model_type=text_to_image_config.model_type,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
         user=user,
         agent=agent,
     )
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 70972eac..fdf1ba9f 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -62,7 +62,7 @@ async def search_online(
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
     custom_filters: List[str] = [],
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     query += " ".join(custom_filters)
@@ -73,7 +73,7 @@ async def search_online(

     # Breakdown the query into subqueries to get the correct answer
     subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
+        query, conversation_history, location, user, query_images=query_images, agent=agent
     )

     response_dict = {}
@@ -151,7 +151,7 @@ async def read_webpages(
     location: LocationData,
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     "Infer web pages to read from the query and extract relevant information from them"
@@ -159,7 +159,7 @@ async def read_webpages(
     if send_status_func:
         async for event in send_status_func(f"**Inferring web pages to read**"):
             yield {ChatEvent.STATUS: event}
-    urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
+    urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)

     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:
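search_online and read_webpages now accept the list of uploaded image URLs and forward it to the subquery and URL inference actors, so attached images can steer what gets searched or read. A hypothetical invocation of the subquery actor from an async context (the query, history, and URL are placeholders):

    from khoj.routers.helpers import generate_online_subqueries

    subqueries = await generate_online_subqueries(
        "where was this photo taken?",
        conversation_history={},
        location_data=None,
        user=user,  # an authenticated KhojUser
        query_images=["https://bucket.example/photo.webp"],
    )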
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 59948b47..075c8c47 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -340,7 +340,7 @@ async def extract_references_and_questions(
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
    agent: Agent = None,
 ):
     user = request.user.object if request.user.is_authenticated else None
@@ -431,7 +431,7 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=query_images,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index d57b5530..ee84c554 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -535,7 +535,7 @@ class ChatRequestBody(BaseModel):
     country: Optional[str] = None
     country_code: Optional[str] = None
     timezone: Optional[str] = None
-    image: Optional[str] = None
+    images: Optional[list[str]] = None
     create_new: Optional[bool] = False


@@ -564,9 +564,9 @@ async def chat(
     country = body.country or get_country_name_from_timezone(body.timezone)
     country_code = body.country_code or get_country_code_from_timezone(body.timezone)
     timezone = body.timezone
-    image = body.image
+    raw_images = body.images

-    async def event_generator(q: str, image: str):
+    async def event_generator(q: str, images: list[str]):
         start_time = time.perf_counter()
         ttft = None
         chat_metadata: dict = {}
@@ -576,16 +576,16 @@ async def chat(
         q = unquote(q)
         nonlocal conversation_id

-        uploaded_image_url = None
-        if image:
-            decoded_string = unquote(image)
-            base64_data = decoded_string.split(",", 1)[1]
-            image_bytes = base64.b64decode(base64_data)
-            webp_image_bytes = convert_image_to_webp(image_bytes)
-            try:
-                uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
-            except:
-                uploaded_image_url = None
+        uploaded_images: list[str] = []
+        if images:
+            for image in images:
+                decoded_string = unquote(image)
+                base64_data = decoded_string.split(",", 1)[1]
+                image_bytes = base64.b64decode(base64_data)
+                webp_image_bytes = convert_image_to_webp(image_bytes)
+                uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
+                if uploaded_image:
+                    uploaded_images.append(uploaded_image)

         async def send_event(event_type: ChatEvent, data: str | dict):
             nonlocal connection_alive, ttft
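API change: the request body field image (a single string) becomes images, a list. Each entry is expected to be a URL-encoded base64 data URI; the loop above unquotes it, splits off the data:...;base64, prefix at the first comma, decodes, converts to webp, and uploads. A hypothetical client request; the endpoint path, port, and file name are assumptions:

    import base64
    from urllib.parse import quote

    import requests

    with open("dog.png", "rb") as f:
        data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

    requests.post(
        "http://localhost:42110/api/chat",  # assumed local Khoj server
        json={"q": "What breed is this dog?", "images": [quote(data_uri)]},
    )

Also worth noting: the old code swallowed upload failures with a bare try/except, while the new loop only skips an image when upload_image_to_bucket returns a falsy result, so an exception raised during upload now propagates.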
@@ -692,7 +692,7 @@ async def chat(
                 meta_log,
                 is_automated_task,
                 user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             )
             conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -701,7 +701,7 @@ async def chat(
             ):
                 yield result

-            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
+            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
             async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
                 yield result
             if mode not in conversation_commands:
@@ -764,7 +764,7 @@ async def chat(
                     q,
                     contextual_data,
                     conversation_history=meta_log,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                     user=user,
                     agent=agent,
                 )
@@ -785,7 +785,7 @@ async def chat(
                     intent_type="summarize",
                     client_application=request.user.client_app,
                     conversation_id=conversation_id,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                 )
                 return
@@ -828,7 +828,7 @@ async def chat(
                 conversation_id=conversation_id,
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             async for result in send_llm_response(llm_response):
                 yield result
@@ -848,7 +848,7 @@ async def chat(
             conversation_commands,
             location,
             partial(send_event, ChatEvent.STATUS),
-            uploaded_image_url=uploaded_image_url,
+            query_images=uploaded_images,
             agent=agent,
         ):
             if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -892,7 +892,7 @@ async def chat(
                 user,
                 partial(send_event, ChatEvent.STATUS),
                 custom_filters,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -916,7 +916,7 @@ async def chat(
                 location,
                 user,
                 partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -966,20 +966,20 @@ async def chat(
                 references=compiled_references,
                 online_results=online_results,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
                 else:
-                    image, status_code, improved_image_prompt, intent_type = result
+                    generated_image, status_code, improved_image_prompt, intent_type = result

-            if image is None or status_code != 200:
+            if generated_image is None or status_code != 200:
                 content_obj = {
                     "content-type": "application/json",
                     "intentType": intent_type,
                     "detail": improved_image_prompt,
-                    "image": image,
+                    "image": None,
                 }
                 async for result in send_llm_response(json.dumps(content_obj)):
                     yield result
@@ -987,7 +987,7 @@ async def chat(

             await sync_to_async(save_to_conversation_log)(
                 q,
-                image,
+                generated_image,
                 user,
                 meta_log,
                 user_message_time,
@@ -997,12 +997,12 @@ async def chat(
                 conversation_id=conversation_id,
                 compiled_references=compiled_references,
                 online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             content_obj = {
                 "intentType": intent_type,
                 "inferredQueries": [improved_image_prompt],
-                "image": image,
+                "image": generated_image,
             }
             async for result in send_llm_response(json.dumps(content_obj)):
                 yield result
@@ -1024,7 +1024,7 @@ async def chat(
                 conversation_id,
                 location,
                 user_name,
-                uploaded_image_url,
+                uploaded_images,
             )

             # Send Response
@@ -1050,9 +1050,9 @@ async def chat(

     ## Stream Text Response
     if stream:
-        return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
+        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
     ## Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, image=image)
+        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
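In the image-generation branch above, the result variable is renamed to generated_image so it no longer shadows the request's image field, and the failure payload now reports "image": None explicitly. Roughly, the two JSON payloads the endpoint emits (all values here are illustrative):

    # Failure: the image actor returned no image or a non-200 status
    failure_payload = {
        "content-type": "application/json",
        "intentType": "text-to-image",           # example intent
        "detail": "improved image prompt ...",   # the inferred prompt
        "image": None,
    }

    # Success
    success_payload = {
        "intentType": "text-to-image",
        "inferredQueries": ["improved image prompt ..."],
        "image": "<generated image as returned by text_to_image>",
    }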
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 12616e36..7ed9c72d 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -290,7 +290,7 @@ async def aget_relevant_information_sources(
     conversation_history: dict,
     is_task: bool,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """
@@ -309,8 +309,8 @@ async def aget_relevant_information_sources(

     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -367,7 +367,7 @@ async def aget_relevant_output_modes(
     conversation_history: dict,
     is_task: bool = False,
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """
@@ -389,8 +389,8 @@ async def aget_relevant_output_modes(

     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
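Since the source and output-mode classifiers above only see text, attached images are represented by a placeholder that now carries the image count rather than assuming a single image. For two attachments:

    query_images = ["https://bucket.example/a.webp", "https://bucket.example/b.webp"]
    query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
    # The classifier prompt sees:
    # [placeholder for 2 user attached images]
    # <original query>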
@@ -433,7 +433,7 @@ async def infer_webpage_urls(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """
@@ -459,7 +459,7 @@ async def infer_webpage_urls(

     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -479,7 +479,7 @@ async def generate_online_subqueries(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """
@@ -505,7 +505,7 @@ async def generate_online_subqueries(

     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list
@@ -524,7 +524,7 @@ async def generate_online_subqueries(


 async def schedule_query(
-    q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
+    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
 ) -> Tuple[str, ...]:
     """
     Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -537,7 +537,7 @@ async def schedule_query(
     )

     raw_response = await send_message_to_model_wrapper(
-        crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+        crontime_prompt, query_images=query_images, response_type="json_object", user=user
     )

     # Validate that the response is a non-empty, JSON-serializable list
@@ -583,7 +583,7 @@ async def extract_relevant_summary(
     q: str,
     corpus: str,
     conversation_history: dict,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> Union[str, None]:
@@ -612,7 +612,7 @@ async def extract_relevant_summary(
         extract_relevant_information,
         prompts.system_prompt_extract_relevant_summary,
         user=user,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
     )
     return response.strip()

@@ -624,7 +624,7 @@ async def generate_better_image_prompt(
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
     model_type: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> str:
@@ -676,7 +676,7 @@ async def generate_better_image_prompt(
     )

     with timer("Chat actor: Generate contextual image prompt", logger):
-        response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
+        response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
             response = response[1:-1]
@@ -689,11 +689,11 @@ async def send_message_to_model_wrapper(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
-    if not vision_available and uploaded_image_url:
+    if not vision_available and query_images:
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
@@ -746,7 +746,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )
@@ -766,7 +766,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )
@@ -784,7 +784,8 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(
@@ -875,6 +876,7 @@ def send_message_to_model_wrapper_sync(
             model_name=chat_model,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(
@@ -900,7 +902,7 @@ def generate_chat_response(
     conversation_id: str = None,
     location_data: LocationData = None,
     user_name: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
@@ -919,12 +921,12 @@ def generate_chat_response(
             inferred_queries=inferred_queries,
             client_application=client_application,
             conversation_id=conversation_id,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
         )

         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
         vision_available = conversation_config.vision_enabled
-        if not vision_available and uploaded_image_url:
+        if not vision_available and query_images:
             vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
             if vision_enabled_config:
                 conversation_config = vision_enabled_config
@@ -955,7 +957,7 @@ def generate_chat_response(
             chat_response = converse(
                 compiled_references,
                 q,
-                image_url=uploaded_image_url,
+                query_images=query_images,
                 online_results=online_results,
                 conversation_log=meta_log,
                 model=chat_model,
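Taken together, the patch threads one rename end to end: a single uploaded_image_url: str becomes query_images: List[str] at every layer, from ChatRequestBody through the chat actors down to message construction and conversation logging. One compatibility note: generate_chatml_messages_with_context now replays images from chat.get("images"), so messages logged before this change under the old "uploadedImageData" key will be reconstructed without their image content. A condensed sketch of the new flow, with hypothetical values:

    # 1. Client: ChatRequestBody.images = [url-encoded base64 data URI, ...]
    # 2. /api/chat: decode -> webp -> upload to bucket, collecting URLs
    uploaded_images = ["https://bucket.example/a.webp", "https://bucket.example/b.webp"]
    # 3. Actors (source/mode selection, online search, summarize, image gen)
    #    receive them as query_images and pass them to the model wrappers
    # 4. generate_chatml_messages_with_context -> construct_structured_message
    #    builds multi-part content for vision-enabled OpenAI models
    # 5. save_to_conversation_log stores them under user message metadata "images"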