Limit vision_enabled image formatting to OpenAI APIs and pass vision input to the extract_questions query

sabaimran 2024-09-10 20:08:14 -07:00
parent aa31d041f3
commit 8d40fc0aef
7 changed files with 54 additions and 25 deletions
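
In short: image parts are attached to chat messages only for OpenAI-type models, and the query-extraction step now receives the uploaded image. A minimal sketch of the gating this commit introduces, using the construct_structured_message helper from the utils diff below (example values are hypothetical):

from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message

# OpenAI model type with vision enabled: message becomes a multi-part payload.
construct_structured_message(
    message="What is in this image?",
    image_url="https://example.com/photo.png",  # hypothetical upload URL
    model_type=ChatModelOptions.ModelType.OPENAI,
    vision_enabled=True,
)
# -> [{"type": "text", "text": ...}, {"type": "image_url", "image_url": {"url": ...}}]

# Any other model type (or vision disabled): the plain string passes through.
construct_structured_message(
    message="What is in this image?",
    image_url="https://example.com/photo.png",
    model_type=ChatModelOptions.ModelType.ANTHROPIC,
    vision_enabled=True,
)
# -> "What is in this image?"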

@@ -6,7 +6,7 @@ from typing import Dict, Optional
 
 from langchain.schema import ChatMessage
 
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
@@ -188,6 +188,7 @@ def converse_anthropic(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
     )
 
     if len(messages) > 1:

@@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
 
 from langchain.schema import ChatMessage
 from llama_cpp import Llama
 
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
@@ -76,7 +76,11 @@ def extract_questions_offline(
     )
 
     messages = generate_chatml_messages_with_context(
-        example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+        example_questions,
+        model_name=model,
+        loaded_model=offline_chat_model,
+        max_prompt_size=max_prompt_size,
+        model_type=ChatModelOptions.ModelType.OFFLINE,
     )
 
     state.chat_lock.acquire()
@@ -201,6 +205,7 @@ def converse_offline(
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.OFFLINE,
     )
 
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})

@@ -5,13 +5,16 @@ from typing import Dict, Optional
 
 from langchain.schema import ChatMessage
 
-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.openai.utils import (
     chat_completion_with_backoff,
     completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
@@ -24,9 +27,10 @@ def extract_questions(
     conversation_log={},
     api_key=None,
     api_base_url=None,
-    temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
+    uploaded_image_url: Optional[str] = None,
+    vision_enabled: bool = False,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -63,17 +67,17 @@ def extract_questions(
         location=location,
         username=username,
     )
+    prompt = construct_structured_message(
+        message=prompt,
+        image_url=uploaded_image_url,
+        model_type=ChatModelOptions.ModelType.OPENAI,
+        vision_enabled=vision_enabled,
+    )
+
     messages = [ChatMessage(content=prompt, role="user")]
 
     # Get Response from GPT
-    response = completion_with_backoff(
-        messages=messages,
-        model=model,
-        temperature=temperature,
-        api_base_url=api_base_url,
-        model_kwargs={"response_format": {"type": "json_object"}},
-        openai_api_key=api_key,
-    )
+    response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url)
 
     # Extract, Clean Message from GPT's Response
     try:
@@ -182,6 +186,7 @@ def converse(
         tokenizer_name=tokenizer_name,
         uploaded_image_url=image_url,
         vision_enabled=vision_available,
+        model_type=ChatModelOptions.ModelType.OPENAI,
     )
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")

@@ -12,7 +12,7 @@ from llama_cpp.llama import Llama
 from transformers import AutoTokenizer
 
 from khoj.database.adapters import ConversationAdapters
-from khoj.database.models import ClientApplication, KhojUser
+from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
@@ -137,6 +137,13 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}"
 )
 
 
+# Format user and system messages to chatml format
+def construct_structured_message(message, image_url, model_type, vision_enabled):
+    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
+        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
+    return message
+
+
 def generate_chatml_messages_with_context(
     user_message,
     system_message=None,
@@ -147,6 +154,7 @@ def generate_chatml_messages_with_context(
     tokenizer_name=None,
     uploaded_image_url=None,
     vision_enabled=False,
+    model_type="",
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -156,12 +164,6 @@ def generate_chatml_messages_with_context(
     else:
         max_prompt_size = model_to_prompt_size.get(model_name, 2000)
 
-    # Format user and system messages to chatml format
-    def construct_structured_message(message, image_url):
-        if image_url and vision_enabled:
-            return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
-        return message
-
     # Scale lookback turns proportional to max prompt size supported by model
     lookback_turns = max_prompt_size // 750
 
@@ -174,7 +176,9 @@ def generate_chatml_messages_with_context(
             message_content = chat["message"] + message_notes
 
             if chat.get("uploadedImageData") and vision_enabled:
-                message_content = construct_structured_message(message_content, chat.get("uploadedImageData"))
+                message_content = construct_structured_message(
+                    message_content, chat.get("uploadedImageData"), model_type, vision_enabled
+                )
 
             reconstructed_message = ChatMessage(content=message_content, role=role)
 
@@ -186,7 +190,10 @@ def generate_chatml_messages_with_context(
     messages = []
     if not is_none_or_empty(user_message):
         messages.append(
-            ChatMessage(content=construct_structured_message(user_message, uploaded_image_url), role="user")
+            ChatMessage(
+                content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
+                role="user",
+            )
         )
     if len(chatml_messages) > 0:
         messages += chatml_messages
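
Since the helper is now module-level and model-aware, callers of generate_chatml_messages_with_context pass the model type alongside the vision flag. A usage sketch (values are illustrative):

messages = generate_chatml_messages_with_context(
    user_message="Describe the attached photo",
    model_name="gpt-4o",  # hypothetical
    uploaded_image_url="https://example.com/photo.png",
    vision_enabled=True,
    model_type=ChatModelOptions.ModelType.OPENAI,  # non-OpenAI types keep plain text
)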

@@ -331,6 +331,7 @@ async def extract_references_and_questions(
     conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
     location_data: LocationData = None,
     send_status_func: Optional[Callable] = None,
+    uploaded_image_url: Optional[str] = None,
 ):
     user = request.user.object if request.user.is_authenticated else None
 
@@ -370,6 +371,7 @@ async def extract_references_and_questions(
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
         conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        vision_enabled = conversation_config.vision_enabled
 
         if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             using_offline_chat = True
@@ -403,6 +405,8 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
+                uploaded_image_url=uploaded_image_url,
+                vision_enabled=vision_enabled,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
             api_key = conversation_config.openai_config.api_key
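
Putting the routing change together: the extractor reads the vision flag off the active chat model configuration and forwards it, along with the uploaded image URL, into the OpenAI branch. Simplified from the diff above; defiltered_query, chat_model, and api_key are assumed surrounding variables:

conversation_config = await ConversationAdapters.aget_default_conversation_config()
vision_enabled = conversation_config.vision_enabled

if conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
    inferred_queries = extract_questions(
        defiltered_query,  # assumed local holding the user's search query
        model=chat_model,  # assumed; the configured OpenAI chat model name
        api_key=api_key,
        conversation_log=meta_log,
        location_data=location_data,
        user=user,
        uploaded_image_url=uploaded_image_url,
        vision_enabled=vision_enabled,
    )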

@@ -807,6 +807,7 @@ async def chat(
                 conversation_commands,
                 location,
                 partial(send_event, ChatEvent.STATUS),
+                uploaded_image_url=uploaded_image_url,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]

@@ -330,7 +330,7 @@ async def aget_relevant_output_modes(
     chat_history = construct_chat_history(conversation_history)
 
     if uploaded_image_url:
-        query = f"[placeholder for image attached to this message]\n{query}"
+        query = f"<user uploaded content redacted> \n{query}"
 
     relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
         query=query,
@@ -622,6 +622,7 @@ async def send_message_to_model_wrapper(
            tokenizer_name=tokenizer,
            max_prompt_size=max_tokens,
            vision_enabled=vision_available,
+           model_type=conversation_config.model_type,
        )
 
        return send_message_to_model_offline(
@@ -644,6 +645,7 @@ async def send_message_to_model_wrapper(
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
+           model_type=conversation_config.model_type,
        )
 
        openai_response = send_message_to_model(
@@ -664,6 +666,7 @@ async def send_message_to_model_wrapper(
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
+           model_type=conversation_config.model_type,
        )
 
        return anthropic_send_message_to_model(
@@ -700,6 +703,7 @@ def send_message_to_model_wrapper_sync(
            model_name=chat_model,
            loaded_model=loaded_model,
            vision_enabled=vision_available,
+           model_type=conversation_config.model_type,
        )
 
        return send_message_to_model_offline(
@@ -717,6 +721,7 @@ def send_message_to_model_wrapper_sync(
            system_message=system_message,
            model_name=chat_model,
            vision_enabled=vision_available,
+           model_type=conversation_config.model_type,
        )
 
        openai_response = send_message_to_model(
@@ -733,6 +738,7 @@ def send_message_to_model_wrapper_sync(
            model_name=chat_model,
            max_prompt_size=max_tokens,
            vision_enabled=vision_available,
+           model_type=conversation_config.model_type,
        )
 
        return anthropic_send_message_to_model(