mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Support Google's Gemini model series (#902)
* Add functions to chat with Google's gemini model series * Gracefully close thread when there's an exception in the gemini llm thread * Use enums for verifying the chat model option type * Add a migration to add the gemini chat model type to the db model * Fix chat model selection verification and math prompt tuning * Fix extract questions method with gemini. Enforce json response in extract questions. * Add standard stop sequence for Gemini chat response generation --------- Co-authored-by: sabaimran <narmiabas@gmail.com> Co-authored-by: Debanjum Singh Solanky <debanjum@gmail.com>
This commit is contained in:
parent
42b727e926
commit
9570933506
11 changed files with 438 additions and 16 deletions
|
@ -87,7 +87,8 @@ dependencies = [
|
|||
"cron-descriptor == 1.4.3",
|
||||
"django_apscheduler == 0.6.2",
|
||||
"anthropic == 0.26.1",
|
||||
"docx2txt == 0.8"
|
||||
"docx2txt == 0.8",
|
||||
"google-generativeai == 0.7.2"
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -973,7 +973,7 @@ class ConversationAdapters:
|
|||
if conversation_config is None:
|
||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
|
@ -982,7 +982,12 @@ class ConversationAdapters:
|
|||
return conversation_config
|
||||
|
||||
if (
|
||||
conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
|
||||
conversation_config.model_type
|
||||
in [
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
]
|
||||
) and conversation_config.openai_config:
|
||||
return conversation_config
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Generated by Django 5.0.8 on 2024-09-12 20:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0060_merge_20240905_1828"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="chatmodeloptions",
|
||||
name="model_type",
|
||||
field=models.CharField(
|
||||
choices=[
|
||||
("openai", "Openai"),
|
||||
("offline", "Offline"),
|
||||
("anthropic", "Anthropic"),
|
||||
("google", "Google"),
|
||||
],
|
||||
default="offline",
|
||||
max_length=200,
|
||||
),
|
||||
),
|
||||
]
|
|
@ -87,6 +87,7 @@ class ChatModelOptions(BaseModel):
|
|||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
ANTHROPIC = "anthropic"
|
||||
GOOGLE = "google"
|
||||
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
|
|
0
src/khoj/processor/conversation/google/__init__.py
Normal file
0
src/khoj/processor/conversation/google/__init__.py
Normal file
221
src/khoj/processor/conversation/google/gemini_chat.py
Normal file
221
src/khoj/processor/conversation/google/gemini_chat.py
Normal file
|
@ -0,0 +1,221 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions_gemini(
|
||||
text,
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
|
||||
]
|
||||
)
|
||||
|
||||
# Get dates relative to today for prompt creation
|
||||
today = datetime.today()
|
||||
current_new_year = today.replace(month=1, day=1)
|
||||
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||
|
||||
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||
current_date=today.strftime("%Y-%m-%d"),
|
||||
day_of_week=today.strftime("%A"),
|
||||
current_month=today.strftime("%Y-%m"),
|
||||
last_new_year=last_new_year.strftime("%Y"),
|
||||
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
chat_history=chat_history,
|
||||
text=text,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
|
||||
model_kwargs = {"response_mime_type": "application/json"}
|
||||
|
||||
response = gemini_completion_with_backoff(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
max_tokens=max_tokens,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||
return [text]
|
||||
return response
|
||||
except:
|
||||
logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
logger.debug(f"Extracted Questions by Gemini: {questions}")
|
||||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
system_prompt = None
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
else:
|
||||
system_prompt = ""
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
|
||||
model_kwargs = {}
|
||||
if response_type == "json_object":
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
|
||||
def converse_gemini(
|
||||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
api_key: Optional[str] = None,
|
||||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
name=agent.name,
|
||||
bio=agent.personality,
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(
|
||||
current_date=current_date.strftime("%Y-%m-%d"),
|
||||
day_of_week=current_date.strftime("%A"),
|
||||
)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
conversation_primer = (
|
||||
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
||||
)
|
||||
if not is_none_or_empty(compiled_references):
|
||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
conversation_log=conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
for message in messages.copy():
|
||||
if message.role == "system":
|
||||
system_prompt += message.content
|
||||
messages.remove(message)
|
||||
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
||||
|
||||
# Get Response from Google AI
|
||||
return gemini_chat_completion_with_backoff(
|
||||
messages=messages,
|
||||
compiled_references=references,
|
||||
online_results=online_results,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
system_prompt=system_prompt,
|
||||
completion_func=completion_func,
|
||||
max_prompt_size=max_prompt_size,
|
||||
)
|
93
src/khoj/processor/conversation/google/utils.py
Normal file
93
src/khoj/processor/conversation/google/utils.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import logging
|
||||
from threading import Thread
|
||||
|
||||
import google.generativeai as genai
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_MAX_TOKENS_GEMINI = 8192
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||
) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
|
||||
model_kwargs = model_kwargs or dict()
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = max_tokens
|
||||
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||
return aggregated_response.text
|
||||
|
||||
|
||||
@retry(
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(2),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def gemini_chat_completion_with_backoff(
|
||||
messages,
|
||||
compiled_references,
|
||||
online_results,
|
||||
model_name,
|
||||
temperature,
|
||||
api_key,
|
||||
system_prompt,
|
||||
max_prompt_size=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(
|
||||
target=gemini_llm_thread,
|
||||
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def gemini_llm_thread(
|
||||
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||
):
|
||||
try:
|
||||
genai.configure(api_key=api_key)
|
||||
max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
|
||||
model_kwargs = model_kwargs or dict()
|
||||
model_kwargs["temperature"] = temperature
|
||||
model_kwargs["max_output_tokens"] = max_tokens
|
||||
model_kwargs["stop_sequences"] = ["Notes:\n["]
|
||||
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
||||
g.send(chunk.text)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
|
@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities:
|
|||
- You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
|
||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||
- Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
|
||||
- inline math mode : `\\(` and `\\)`
|
||||
- display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
|
||||
- inline math mode : \\( and \\)
|
||||
- display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||
- Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
|
||||
|
|
|
@ -7,6 +7,7 @@ from collections import defaultdict
|
|||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
|
||||
|
@ -94,7 +95,7 @@ async def search_online(
|
|||
|
||||
# Read, extract relevant info from the retrieved web pages
|
||||
if webpages:
|
||||
webpage_links = [link for link, _, _ in webpages]
|
||||
webpage_links = set([link for link, _, _ in webpages])
|
||||
logger.info(f"Reading web pages at: {list(webpage_links)}")
|
||||
if send_status_func:
|
||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||
|
|
|
@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
|
|||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||
extract_questions_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
|
||||
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
|
@ -419,6 +420,18 @@ async def extract_references_and_questions(
|
|||
location_data=location_data,
|
||||
user=user,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
max_tokens=conversation_config.max_prompt_size,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
with timer("Searching knowledge base took", logger):
|
||||
|
|
|
@ -76,6 +76,10 @@ from khoj.processor.conversation.anthropic.anthropic_chat import (
|
|||
anthropic_send_message_to_model,
|
||||
converse_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.google.gemini_chat import (
|
||||
converse_gemini,
|
||||
gemini_send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
send_message_to_model_offline,
|
||||
|
@ -136,7 +140,7 @@ async def is_ready_to_chat(user: KhojUser):
|
|||
await ConversationAdapters.aget_default_conversation_config()
|
||||
)
|
||||
|
||||
if user_conversation_config and user_conversation_config.model_type == "offline":
|
||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
chat_model = user_conversation_config.chat_model
|
||||
max_tokens = user_conversation_config.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
|
@ -146,7 +150,14 @@ async def is_ready_to_chat(user: KhojUser):
|
|||
|
||||
if (
|
||||
user_conversation_config
|
||||
and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
|
||||
and (
|
||||
user_conversation_config.model_type
|
||||
in [
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
]
|
||||
)
|
||||
and user_conversation_config.openai_config
|
||||
):
|
||||
return True
|
||||
|
@ -607,9 +618,10 @@ async def send_message_to_model_wrapper(
|
|||
else conversation_config.max_prompt_size
|
||||
)
|
||||
tokenizer = conversation_config.tokenizer
|
||||
model_type = conversation_config.model_type
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
|
@ -633,7 +645,7 @@ async def send_message_to_model_wrapper(
|
|||
response_type=response_type,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
|
@ -657,7 +669,7 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
return openai_response
|
||||
elif conversation_config.model_type == "anthropic":
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
|
@ -666,6 +678,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
|
@ -674,6 +687,21 @@ async def send_message_to_model_wrapper(
|
|||
api_key=api_key,
|
||||
model=chat_model,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
@ -692,7 +720,7 @@ def send_message_to_model_wrapper_sync(
|
|||
max_tokens = conversation_config.max_prompt_size
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
|
@ -714,7 +742,7 @@ def send_message_to_model_wrapper_sync(
|
|||
response_type=response_type,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
|
@ -730,7 +758,7 @@ def send_message_to_model_wrapper_sync(
|
|||
|
||||
return openai_response
|
||||
|
||||
elif conversation_config.model_type == "anthropic":
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
|
@ -746,6 +774,22 @@ def send_message_to_model_wrapper_sync(
|
|||
api_key=api_key,
|
||||
model=chat_model,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
@ -811,7 +855,7 @@ def generate_chat_response(
|
|||
agent=agent,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
|
@ -834,7 +878,7 @@ def generate_chat_response(
|
|||
vision_available=vision_available,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "anthropic":
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
|
@ -851,6 +895,23 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
q,
|
||||
online_results,
|
||||
meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
||||
|
|
Loading…
Reference in a new issue