mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Support Google's Gemini model series (#902)
* Add functions to chat with Google's gemini model series * Gracefully close thread when there's an exception in the gemini llm thread * Use enums for verifying the chat model option type * Add a migration to add the gemini chat model type to the db model * Fix chat model selection verification and math prompt tuning * Fix extract questions method with gemini. Enforce json response in extract questions. * Add standard stop sequence for Gemini chat response generation --------- Co-authored-by: sabaimran <narmiabas@gmail.com> Co-authored-by: Debanjum Singh Solanky <debanjum@gmail.com>
This commit is contained in:
parent
42b727e926
commit
9570933506
11 changed files with 438 additions and 16 deletions
|
@ -87,7 +87,8 @@ dependencies = [
|
||||||
"cron-descriptor == 1.4.3",
|
"cron-descriptor == 1.4.3",
|
||||||
"django_apscheduler == 0.6.2",
|
"django_apscheduler == 0.6.2",
|
||||||
"anthropic == 0.26.1",
|
"anthropic == 0.26.1",
|
||||||
"docx2txt == 0.8"
|
"docx2txt == 0.8",
|
||||||
|
"google-generativeai == 0.7.2"
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|
|
@ -973,7 +973,7 @@ class ConversationAdapters:
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
@ -982,7 +982,12 @@ class ConversationAdapters:
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
if (
|
if (
|
||||||
conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
|
conversation_config.model_type
|
||||||
|
in [
|
||||||
|
ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
|
ChatModelOptions.ModelType.OPENAI,
|
||||||
|
ChatModelOptions.ModelType.GOOGLE,
|
||||||
|
]
|
||||||
) and conversation_config.openai_config:
|
) and conversation_config.openai_config:
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-09-12 20:06
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0060_merge_20240905_1828"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="chatmodeloptions",
|
||||||
|
name="model_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[
|
||||||
|
("openai", "Openai"),
|
||||||
|
("offline", "Offline"),
|
||||||
|
("anthropic", "Anthropic"),
|
||||||
|
("google", "Google"),
|
||||||
|
],
|
||||||
|
default="offline",
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -87,6 +87,7 @@ class ChatModelOptions(BaseModel):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
ANTHROPIC = "anthropic"
|
ANTHROPIC = "anthropic"
|
||||||
|
GOOGLE = "google"
|
||||||
|
|
||||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
|
|
0
src/khoj/processor/conversation/google/__init__.py
Normal file
0
src/khoj/processor/conversation/google/__init__.py
Normal file
221
src/khoj/processor/conversation/google/gemini_chat.py
Normal file
221
src/khoj/processor/conversation/google/gemini_chat.py
Normal file
|
@ -0,0 +1,221 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
from khoj.database.models import Agent, KhojUser
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.processor.conversation.google.utils import (
|
||||||
|
gemini_chat_completion_with_backoff,
|
||||||
|
gemini_completion_with_backoff,
|
||||||
|
)
|
||||||
|
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||||
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_questions_gemini(
|
||||||
|
text,
|
||||||
|
model: Optional[str] = "gemini-1.5-flash",
|
||||||
|
conversation_log={},
|
||||||
|
api_key=None,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=None,
|
||||||
|
location_data: LocationData = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
"""
|
||||||
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||||
|
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||||
|
|
||||||
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
|
chat_history = "".join(
|
||||||
|
[
|
||||||
|
f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||||
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
|
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get dates relative to today for prompt creation
|
||||||
|
today = datetime.today()
|
||||||
|
current_new_year = today.replace(month=1, day=1)
|
||||||
|
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||||
|
|
||||||
|
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||||
|
current_date=today.strftime("%Y-%m-%d"),
|
||||||
|
day_of_week=today.strftime("%A"),
|
||||||
|
current_month=today.strftime("%Y-%m"),
|
||||||
|
last_new_year=last_new_year.strftime("%Y"),
|
||||||
|
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||||
|
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||||
|
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||||
|
location=location,
|
||||||
|
username=username,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||||
|
chat_history=chat_history,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [ChatMessage(content=prompt, role="user")]
|
||||||
|
|
||||||
|
model_kwargs = {"response_mime_type": "application/json"}
|
||||||
|
|
||||||
|
response = gemini_completion_with_backoff(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model_name=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=api_key,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract, Clean Message from Gemini's Response
|
||||||
|
try:
|
||||||
|
response = response.strip()
|
||||||
|
match = re.search(r"\{.*?\}", response)
|
||||||
|
if match:
|
||||||
|
response = match.group()
|
||||||
|
response = json.loads(response)
|
||||||
|
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||||
|
if not isinstance(response, list) or not response:
|
||||||
|
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||||
|
return [text]
|
||||||
|
return response
|
||||||
|
except:
|
||||||
|
logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||||
|
questions = [text]
|
||||||
|
logger.debug(f"Extracted Questions by Gemini: {questions}")
|
||||||
|
return questions
|
||||||
|
|
||||||
|
|
||||||
|
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
|
||||||
|
"""
|
||||||
|
Send message to model
|
||||||
|
"""
|
||||||
|
system_prompt = None
|
||||||
|
if len(messages) == 1:
|
||||||
|
messages[0].role = "user"
|
||||||
|
else:
|
||||||
|
system_prompt = ""
|
||||||
|
for message in messages.copy():
|
||||||
|
if message.role == "system":
|
||||||
|
system_prompt += message.content
|
||||||
|
messages.remove(message)
|
||||||
|
|
||||||
|
model_kwargs = {}
|
||||||
|
if response_type == "json_object":
|
||||||
|
model_kwargs["response_mime_type"] = "application/json"
|
||||||
|
|
||||||
|
# Get Response from Gemini
|
||||||
|
return gemini_completion_with_backoff(
|
||||||
|
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def converse_gemini(
|
||||||
|
references,
|
||||||
|
user_query,
|
||||||
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
|
conversation_log={},
|
||||||
|
model: Optional[str] = "gemini-1.5-flash",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
temperature: float = 0.2,
|
||||||
|
completion_func=None,
|
||||||
|
conversation_commands=[ConversationCommand.Default],
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
|
location_data: LocationData = None,
|
||||||
|
user_name: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Converse with user using Google's Gemini
|
||||||
|
"""
|
||||||
|
# Initialize Variables
|
||||||
|
current_date = datetime.now()
|
||||||
|
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
||||||
|
|
||||||
|
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||||
|
|
||||||
|
if agent and agent.personality:
|
||||||
|
system_prompt = prompts.custom_personality.format(
|
||||||
|
name=agent.name,
|
||||||
|
bio=agent.personality,
|
||||||
|
current_date=current_date.strftime("%Y-%m-%d"),
|
||||||
|
day_of_week=current_date.strftime("%A"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_prompt = prompts.personality.format(
|
||||||
|
current_date=current_date.strftime("%Y-%m-%d"),
|
||||||
|
day_of_week=current_date.strftime("%A"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if location_data:
|
||||||
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||||
|
location_prompt = prompts.user_location.format(location=location)
|
||||||
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
|
if user_name:
|
||||||
|
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||||
|
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||||
|
|
||||||
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||||
|
completion_func(chat_response=prompts.no_notes_found.format())
|
||||||
|
return iter([prompts.no_notes_found.format()])
|
||||||
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
|
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||||
|
conversation_primer = (
|
||||||
|
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
||||||
|
)
|
||||||
|
if not is_none_or_empty(compiled_references):
|
||||||
|
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
||||||
|
|
||||||
|
# Setup Prompt with Primer or Conversation History
|
||||||
|
messages = generate_chatml_messages_with_context(
|
||||||
|
conversation_primer,
|
||||||
|
conversation_log=conversation_log,
|
||||||
|
model_name=model,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
tokenizer_name=tokenizer_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message.role == "assistant":
|
||||||
|
message.role = "model"
|
||||||
|
|
||||||
|
for message in messages.copy():
|
||||||
|
if message.role == "system":
|
||||||
|
system_prompt += message.content
|
||||||
|
messages.remove(message)
|
||||||
|
|
||||||
|
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||||
|
logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
|
||||||
|
|
||||||
|
# Get Response from Google AI
|
||||||
|
return gemini_chat_completion_with_backoff(
|
||||||
|
messages=messages,
|
||||||
|
compiled_references=references,
|
||||||
|
online_results=online_results,
|
||||||
|
model_name=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=api_key,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
completion_func=completion_func,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
)
|
93
src/khoj/processor/conversation/google/utils.py
Normal file
93
src/khoj/processor/conversation/google/utils.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import logging
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
import google.generativeai as genai
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
wait_random_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MAX_TOKENS_GEMINI = 8192
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
|
stop=stop_after_attempt(2),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def gemini_completion_with_backoff(
|
||||||
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||||
|
) -> str:
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
|
||||||
|
model_kwargs = model_kwargs or dict()
|
||||||
|
model_kwargs["temperature"] = temperature
|
||||||
|
model_kwargs["max_output_tokens"] = max_tokens
|
||||||
|
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||||
|
|
||||||
|
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||||
|
# all messages up to the last are considered to be part of the chat history
|
||||||
|
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||||
|
# the last message is considered to be the current prompt
|
||||||
|
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||||
|
return aggregated_response.text
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
stop=stop_after_attempt(2),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def gemini_chat_completion_with_backoff(
|
||||||
|
messages,
|
||||||
|
compiled_references,
|
||||||
|
online_results,
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
api_key,
|
||||||
|
system_prompt,
|
||||||
|
max_prompt_size=None,
|
||||||
|
completion_func=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
):
|
||||||
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
|
t = Thread(
|
||||||
|
target=gemini_llm_thread,
|
||||||
|
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def gemini_llm_thread(
|
||||||
|
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
|
||||||
|
model_kwargs = model_kwargs or dict()
|
||||||
|
model_kwargs["temperature"] = temperature
|
||||||
|
model_kwargs["max_output_tokens"] = max_tokens
|
||||||
|
model_kwargs["stop_sequences"] = ["Notes:\n["]
|
||||||
|
model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
|
||||||
|
|
||||||
|
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||||
|
# all messages up to the last are considered to be part of the chat history
|
||||||
|
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||||
|
# the last message is considered to be the current prompt
|
||||||
|
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
||||||
|
g.send(chunk.text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
g.close()
|
|
@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities:
|
||||||
- You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
|
- You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
|
||||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||||
- Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
|
- Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
|
||||||
- inline math mode : `\\(` and `\\)`
|
- inline math mode : \\( and \\)
|
||||||
- display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
|
- display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
|
||||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||||
- Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
|
- Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
|
||||||
|
|
|
@ -7,6 +7,7 @@ from collections import defaultdict
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
|
||||||
|
@ -94,7 +95,7 @@ async def search_online(
|
||||||
|
|
||||||
# Read, extract relevant info from the retrieved web pages
|
# Read, extract relevant info from the retrieved web pages
|
||||||
if webpages:
|
if webpages:
|
||||||
webpage_links = [link for link, _, _ in webpages]
|
webpage_links = set([link for link, _, _ in webpages])
|
||||||
logger.info(f"Reading web pages at: {list(webpage_links)}")
|
logger.info(f"Reading web pages at: {list(webpage_links)}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
||||||
|
|
|
@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
|
||||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
extract_questions_anthropic,
|
extract_questions_anthropic,
|
||||||
)
|
)
|
||||||
|
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
|
||||||
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
||||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
|
@ -419,6 +420,18 @@ async def extract_references_and_questions(
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
chat_model = conversation_config.chat_model
|
||||||
|
inferred_queries = extract_questions_gemini(
|
||||||
|
defiltered_query,
|
||||||
|
model=chat_model,
|
||||||
|
api_key=api_key,
|
||||||
|
conversation_log=meta_log,
|
||||||
|
location_data=location_data,
|
||||||
|
max_tokens=conversation_config.max_prompt_size,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
with timer("Searching knowledge base took", logger):
|
with timer("Searching knowledge base took", logger):
|
||||||
|
|
|
@ -76,6 +76,10 @@ from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
anthropic_send_message_to_model,
|
anthropic_send_message_to_model,
|
||||||
converse_anthropic,
|
converse_anthropic,
|
||||||
)
|
)
|
||||||
|
from khoj.processor.conversation.google.gemini_chat import (
|
||||||
|
converse_gemini,
|
||||||
|
gemini_send_message_to_model,
|
||||||
|
)
|
||||||
from khoj.processor.conversation.offline.chat_model import (
|
from khoj.processor.conversation.offline.chat_model import (
|
||||||
converse_offline,
|
converse_offline,
|
||||||
send_message_to_model_offline,
|
send_message_to_model_offline,
|
||||||
|
@ -136,7 +140,7 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
await ConversationAdapters.aget_default_conversation_config()
|
await ConversationAdapters.aget_default_conversation_config()
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_conversation_config and user_conversation_config.model_type == "offline":
|
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
max_tokens = user_conversation_config.max_prompt_size
|
max_tokens = user_conversation_config.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
|
@ -146,7 +150,14 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_conversation_config
|
user_conversation_config
|
||||||
and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
|
and (
|
||||||
|
user_conversation_config.model_type
|
||||||
|
in [
|
||||||
|
ChatModelOptions.ModelType.OPENAI,
|
||||||
|
ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
|
ChatModelOptions.ModelType.GOOGLE,
|
||||||
|
]
|
||||||
|
)
|
||||||
and user_conversation_config.openai_config
|
and user_conversation_config.openai_config
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
@ -607,9 +618,10 @@ async def send_message_to_model_wrapper(
|
||||||
else conversation_config.max_prompt_size
|
else conversation_config.max_prompt_size
|
||||||
)
|
)
|
||||||
tokenizer = conversation_config.tokenizer
|
tokenizer = conversation_config.tokenizer
|
||||||
|
model_type = conversation_config.model_type
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
|
@ -633,7 +645,7 @@ async def send_message_to_model_wrapper(
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.openai_config
|
openai_chat_config = conversation_config.openai_config
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
api_base_url = openai_chat_config.api_base_url
|
api_base_url = openai_chat_config.api_base_url
|
||||||
|
@ -657,7 +669,7 @@ async def send_message_to_model_wrapper(
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
elif conversation_config.model_type == "anthropic":
|
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
|
@ -666,6 +678,7 @@ async def send_message_to_model_wrapper(
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
model_type=conversation_config.model_type,
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -674,6 +687,21 @@ async def send_message_to_model_wrapper(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
)
|
)
|
||||||
|
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
return gemini_send_message_to_model(
|
||||||
|
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
@ -692,7 +720,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = conversation_config.max_prompt_size
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
|
@ -714,7 +742,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
|
@ -730,7 +758,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
|
||||||
elif conversation_config.model_type == "anthropic":
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
|
@ -746,6 +774,22 @@ def send_message_to_model_wrapper_sync(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
return gemini_send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
@ -811,7 +855,7 @@ def generate_chat_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.openai_config
|
openai_chat_config = conversation_config.openai_config
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
|
@ -834,7 +878,7 @@ def generate_chat_response(
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "anthropic":
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
chat_response = converse_anthropic(
|
chat_response = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
|
@ -851,6 +895,23 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
chat_response = converse_gemini(
|
||||||
|
compiled_references,
|
||||||
|
q,
|
||||||
|
online_results,
|
||||||
|
meta_log,
|
||||||
|
model=conversation_config.chat_model,
|
||||||
|
api_key=api_key,
|
||||||
|
completion_func=partial_completion,
|
||||||
|
conversation_commands=conversation_commands,
|
||||||
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
|
tokenizer_name=conversation_config.tokenizer,
|
||||||
|
location_data=location_data,
|
||||||
|
user_name=user_name,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue