Add attached files to latest structured message in chat ml format

This commit is contained in:
sabaimran 2024-11-09 19:17:00 -08:00
parent 835fa80a4b
commit 92b6b3ef7b
4 changed files with 23 additions and 21 deletions

View file

@ -83,13 +83,11 @@ def extract_questions_anthropic(
images=query_images, images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC, model_type=ChatModelOptions.ModelType.ANTHROPIC,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
attached_file_context=attached_files,
) )
messages = [] messages = []
if attached_files:
messages.append(ChatMessage(content=attached_files, role="user"))
messages.append(ChatMessage(content=prompt, role="user")) messages.append(ChatMessage(content=prompt, role="user"))
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@ -84,13 +84,11 @@ def extract_questions_gemini(
images=query_images, images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE, model_type=ChatModelOptions.ModelType.GOOGLE,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
attached_file_context=attached_files,
) )
messages = [] messages = []
if attached_files:
messages.append(ChatMessage(content=attached_files, role="user"))
messages.append(ChatMessage(content=prompt, role="user")) messages.append(ChatMessage(content=prompt, role="user"))
messages.append(ChatMessage(content=system_prompt, role="system")) messages.append(ChatMessage(content=system_prompt, role="system"))

View file

@ -80,13 +80,10 @@ def extract_questions(
images=query_images, images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI, model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
attached_file_context=attached_files,
) )
messages = [] messages = []
if attached_files:
messages.append(ChatMessage(content=attached_files, role="user"))
messages.append(ChatMessage(content=prompt, role="user")) messages.append(ChatMessage(content=prompt, role="user"))
response = send_message_to_model( response = send_message_to_model(

View file

@ -271,23 +271,31 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
) )
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool): def construct_structured_message(
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
):
""" """
Format messages into appropriate multimedia format for supported chat model types Format messages into appropriate multimedia format for supported chat model types
""" """
if not images or not vision_enabled: if not images or not vision_enabled:
return message return message
constructed_messages = [
{"type": "text", "text": message},
]
if not is_none_or_empty(attached_file_context):
constructed_messages.append({"type": "text", "text": attached_file_context})
if model_type in [ if model_type in [
ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE, ChatModelOptions.ModelType.GOOGLE,
ChatModelOptions.ModelType.ANTHROPIC, ChatModelOptions.ModelType.ANTHROPIC,
]: ]:
return [ for image in images:
{"type": "text", "text": message}, constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
] return constructed_messages
return message
def gather_raw_attached_files( def gather_raw_attached_files(
@ -362,7 +370,9 @@ def generate_chatml_messages_with_context(
chatml_messages.insert(0, reconstructed_context_message) chatml_messages.insert(0, reconstructed_context_message)
role = "user" if chat["by"] == "you" else "assistant" role = "user" if chat["by"] == "you" else "assistant"
message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled) message_content = construct_structured_message(
chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=attached_files
)
reconstructed_message = ChatMessage(content=message_content, role=role) reconstructed_message = ChatMessage(content=message_content, role=role)
chatml_messages.insert(0, reconstructed_message) chatml_messages.insert(0, reconstructed_message)
@ -374,16 +384,15 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(user_message): if not is_none_or_empty(user_message):
messages.append( messages.append(
ChatMessage( ChatMessage(
content=construct_structured_message(user_message, query_images, model_type, vision_enabled), content=construct_structured_message(
user_message, query_images, model_type, vision_enabled, attached_files
),
role="user", role="user",
) )
) )
if not is_none_or_empty(context_message): if not is_none_or_empty(context_message):
messages.append(ChatMessage(content=context_message, role="user")) messages.append(ChatMessage(content=context_message, role="user"))
if not is_none_or_empty(attached_files):
messages.append(ChatMessage(content=attached_files, role="user"))
if len(chatml_messages) > 0: if len(chatml_messages) > 0:
messages += chatml_messages messages += chatml_messages