Mirror of https://github.com/khoj-ai/khoj.git
Limit vision_enabled image formatting to OpenAI APIs and send vision to extract_questions query
Commit 8d40fc0aef (parent aa31d041f3)

7 changed files with 54 additions and 25 deletions
src/khoj/processor/conversation/anthropic/anthropic_chat.py

@@ -6,7 +6,7 @@ from typing import Dict, Optional

 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
@@ -188,6 +188,7 @@ def converse_anthropic(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
     )

     if len(messages) > 1:
src/khoj/processor/conversation/offline/chat_model.py

@@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
 from langchain.schema import ChatMessage
 from llama_cpp import Llama

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
@@ -76,7 +76,11 @@ def extract_questions_offline(
     )

     messages = generate_chatml_messages_with_context(
-        example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+        example_questions,
+        model_name=model,
+        loaded_model=offline_chat_model,
+        max_prompt_size=max_prompt_size,
+        model_type=ChatModelOptions.ModelType.OFFLINE,
     )

     state.chat_lock.acquire()
@@ -201,6 +205,7 @@ def converse_offline(
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        model_type=ChatModelOptions.ModelType.OFFLINE,
     )

     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
src/khoj/processor/conversation/openai/gpt.py

@@ -5,13 +5,16 @@ from typing import Dict, Optional

 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.openai.utils import (
     chat_completion_with_backoff,
     completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -24,9 +27,10 @@ def extract_questions(
     conversation_log={},
     api_key=None,
     api_base_url=None,
     temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
     uploaded_image_url: Optional[str] = None,
+    vision_enabled: bool = False,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -63,17 +67,17 @@ def extract_questions(
         location=location,
         username=username,
     )

+    prompt = construct_structured_message(
+        message=prompt,
+        image_url=uploaded_image_url,
+        model_type=ChatModelOptions.ModelType.OPENAI,
+        vision_enabled=vision_enabled,
+    )
+
     messages = [ChatMessage(content=prompt, role="user")]

     # Get Response from GPT
-    response = completion_with_backoff(
-        messages=messages,
-        model=model,
-        temperature=temperature,
-        api_base_url=api_base_url,
-        model_kwargs={"response_format": {"type": "json_object"}},
-        openai_api_key=api_key,
-    )
+    response = send_message_to_model(messages, api_key, model, response_type="json_object", api_base_url=api_base_url)

     # Extract, Clean Message from GPT's Response
     try:
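For illustration, a minimal sketch of the message shape extract_questions now produces when vision is enabled. The prompt text and URL below are placeholders, not values from this commit:

# Sketch (illustrative values): with vision enabled, the rendered prompt is
# wrapped into OpenAI's list-of-parts content before being sent as a single
# user message.
from langchain.schema import ChatMessage

prompt = "Which questions should I search my notes for?"  # placeholder prompt
uploaded_image_url = "https://example.com/upload.png"  # placeholder upload URL

content = [
    {"type": "text", "text": prompt},
    {"type": "image_url", "image_url": {"url": uploaded_image_url}},
]
messages = [ChatMessage(content=content, role="user")]
# With vision disabled (or a non-OpenAI model), content stays the plain prompt string.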
@@ -182,6 +186,7 @@ def converse(
         tokenizer_name=tokenizer_name,
         uploaded_image_url=image_url,
         vision_enabled=vision_available,
+        model_type=ChatModelOptions.ModelType.OPENAI,
     )
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")
src/khoj/processor/conversation/utils.py

@@ -12,7 +12,7 @@ from llama_cpp.llama import Llama
 from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
-from khoj.database.models import ClientApplication, KhojUser
+from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
@@ -137,6 +137,13 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
 )


+# Format user and system messages to chatml format
+def construct_structured_message(message, image_url, model_type, vision_enabled):
+    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
+        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
+    return message
+
+
 def generate_chatml_messages_with_context(
     user_message,
     system_message=None,
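A self-contained sketch of the gating the new module-level helper performs. ModelType here is a local stand-in for khoj.database.models.ChatModelOptions.ModelType, and the message and URL are illustrative:

from enum import Enum


class ModelType(str, Enum):
    # Stand-in for khoj.database.models.ChatModelOptions.ModelType
    OPENAI = "openai"
    OFFLINE = "offline"
    ANTHROPIC = "anthropic"


def construct_structured_message(message, image_url, model_type, vision_enabled):
    # Image formatting is now limited to OpenAI-type models; every other model
    # type gets the plain message string back, even when an image is attached.
    if image_url and vision_enabled and model_type == ModelType.OPENAI:
        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
    return message


# OpenAI with vision: a text part plus an image_url part.
print(construct_structured_message("What is this?", "https://example.com/cat.png", ModelType.OPENAI, True))
# Anthropic (or offline, or vision disabled): the string passes through unchanged.
print(construct_structured_message("What is this?", "https://example.com/cat.png", ModelType.ANTHROPIC, True))

Making the helper a module-level function (instead of a closure inside generate_chatml_messages_with_context, as removed below) is what lets gpt.py reuse it for extract_questions.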
@@ -147,6 +154,7 @@ def generate_chatml_messages_with_context(
     tokenizer_name=None,
     uploaded_image_url=None,
     vision_enabled=False,
+    model_type="",
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -156,12 +164,6 @@ def generate_chatml_messages_with_context(
     else:
         max_prompt_size = model_to_prompt_size.get(model_name, 2000)

-    # Format user and system messages to chatml format
-    def construct_structured_message(message, image_url):
-        if image_url and vision_enabled:
-            return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
-        return message
-
     # Scale lookback turns proportional to max prompt size supported by model
     lookback_turns = max_prompt_size // 750

@@ -174,7 +176,9 @@ def generate_chatml_messages_with_context(
         message_content = chat["message"] + message_notes

         if chat.get("uploadedImageData") and vision_enabled:
-            message_content = construct_structured_message(message_content, chat.get("uploadedImageData"))
+            message_content = construct_structured_message(
+                message_content, chat.get("uploadedImageData"), model_type, vision_enabled
+            )

         reconstructed_message = ChatMessage(content=message_content, role=role)

@@ -186,7 +190,10 @@ def generate_chatml_messages_with_context(
     messages = []
     if not is_none_or_empty(user_message):
         messages.append(
-            ChatMessage(content=construct_structured_message(user_message, uploaded_image_url), role="user")
+            ChatMessage(
+                content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
+                role="user",
+            )
         )
     if len(chatml_messages) > 0:
         messages += chatml_messages
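Reusing the ModelType stand-in and helper from the sketch above, the reworked user-message path behaves roughly like this (values illustrative):

from langchain.schema import ChatMessage

user_message = "What is in this photo?"
uploaded_image_url = "https://example.com/upload.png"  # illustrative

# Mirrors the final hunk: the current user message is routed through the same
# helper as history messages, so model_type decides the content format.
messages = [
    ChatMessage(
        content=construct_structured_message(user_message, uploaded_image_url, ModelType.OPENAI, True),
        role="user",
    )
]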
src/khoj/routers/api.py

@@ -331,6 +331,7 @@ async def extract_references_and_questions(
     conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
     location_data: LocationData = None,
     send_status_func: Optional[Callable] = None,
+    uploaded_image_url: Optional[str] = None,
 ):
     user = request.user.object if request.user.is_authenticated else None

@@ -370,6 +371,7 @@ async def extract_references_and_questions(
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
         conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        vision_enabled = conversation_config.vision_enabled

         if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             using_offline_chat = True
@@ -403,6 +405,8 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
+                uploaded_image_url=uploaded_image_url,
+                vision_enabled=vision_enabled,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
             api_key = conversation_config.openai_config.api_key
src/khoj/routers/api_chat.py

@@ -807,6 +807,7 @@ async def chat(
                 conversation_commands,
                 location,
                 partial(send_event, ChatEvent.STATUS),
+                uploaded_image_url=uploaded_image_url,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
src/khoj/routers/helpers.py

@@ -330,7 +330,7 @@ async def aget_relevant_output_modes(
     chat_history = construct_chat_history(conversation_history)

     if uploaded_image_url:
-        query = f"[placeholder for image attached to this message]\n{query}"
+        query = f"<user uploaded content redacted> \n{query}"

     relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
         query=query,
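A tiny runnable sketch of the changed redaction in aget_relevant_output_modes; the query and URL below are placeholders:

# The output-mode picker never sees the image content itself; an uploaded
# image is now signalled with a redaction marker ahead of the query text.
uploaded_image_url = "https://example.com/upload.png"  # placeholder
query = "paint a sunset over the ocean"

if uploaded_image_url:
    query = f"<user uploaded content redacted> \n{query}"

print(query)  # prints the redaction marker line, then the original query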
@@ -622,6 +622,7 @@ async def send_message_to_model_wrapper(
             tokenizer_name=tokenizer,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return send_message_to_model_offline(
@@ -644,6 +645,7 @@ async def send_message_to_model_wrapper(
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
             uploaded_image_url=uploaded_image_url,
+            model_type=conversation_config.model_type,
         )

         openai_response = send_message_to_model(
@@ -664,6 +666,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return anthropic_send_message_to_model(
@@ -700,6 +703,7 @@ def send_message_to_model_wrapper_sync(
             model_name=chat_model,
             loaded_model=loaded_model,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return send_message_to_model_offline(
@@ -717,6 +721,7 @@ def send_message_to_model_wrapper_sync(
             system_message=system_message,
             model_name=chat_model,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         openai_response = send_message_to_model(
@@ -733,6 +738,7 @@ def send_message_to_model_wrapper_sync(
             model_name=chat_model,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return anthropic_send_message_to_model(