Support Google's Gemini model series (#902)

* Add functions to chat with Google's Gemini model series
  * Gracefully close the thread when there's an exception in the Gemini LLM thread
* Use enums to verify the chat model option type
* Add a migration to add the Gemini chat model type to the db model
* Fix chat model selection verification and math prompt tuning
* Fix the extract questions method with Gemini. Enforce JSON responses in extract questions.
* Add a standard stop sequence for Gemini chat response generation
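For illustration, a minimal sketch of the stop-sequence idea outside Khoj. The model name and API key handling are placeholders; `Notes:\n[` is the marker Khoj uses to delimit injected notes (see the new `google/utils.py` below), so stopping on it keeps the model from echoing the notes context back:

```python
# Minimal stop-sequence sketch, assuming google-generativeai is installed
# and GEMINI_API_KEY is set; not the full Khoj pipeline.
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
model = genai.GenerativeModel(
    "gemini-1.5-flash",
    # Generation halts if the model starts to emit the notes-context marker.
    generation_config={"stop_sequences": ["Notes:\n["]},
)
print(model.generate_content("Summarize my notes on beekeeping.").text)
```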

---------

Co-authored-by: sabaimran <narmiabas@gmail.com>
Co-authored-by: Debanjum Singh Solanky <debanjum@gmail.com>
Alexander Matyasko 2024-09-13 09:17:55 +08:00 committed by GitHub
parent 42b727e926
commit 9570933506
11 changed files with 438 additions and 16 deletions

@@ -87,7 +87,8 @@ dependencies = [
     "cron-descriptor == 1.4.3",
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
-    "docx2txt == 0.8"
+    "docx2txt == 0.8",
+    "google-generativeai == 0.7.2"
 ]
 dynamic = ["version"]
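A quick way to confirm the pinned dependency imports and matches the expected version (a smoke check, not part of the change itself):

```python
import google.generativeai as genai

print(genai.__version__)  # expect "0.7.2" per the pin above
```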

@@ -973,7 +973,7 @@ class ConversationAdapters:
         if conversation_config is None:
             conversation_config = ConversationAdapters.get_default_conversation_config()
 
-        if conversation_config.model_type == "offline":
+        if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
                 chat_model = conversation_config.chat_model
                 max_tokens = conversation_config.max_prompt_size

@@ -982,7 +982,12 @@ class ConversationAdapters:
 
             return conversation_config
 
         if (
-            conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
+            conversation_config.model_type
+            in [
+                ChatModelOptions.ModelType.ANTHROPIC,
+                ChatModelOptions.ModelType.OPENAI,
+                ChatModelOptions.ModelType.GOOGLE,
+            ]
         ) and conversation_config.openai_config:
             return conversation_config
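These enum membership checks stay compatible with rows that still hold plain strings because Django `TextChoices` members are `str` subclasses (assuming `ModelType` is a `TextChoices`, as the generated choice labels in the migration below suggest). A quick sketch of the property the checks rely on:

```python
from khoj.database.models import ChatModelOptions

# Enum members compare equal to their stored string values, so enum-based
# checks remain backward compatible with existing database rows.
assert ChatModelOptions.ModelType.GOOGLE == "google"
assert "openai" in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]
```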

@@ -0,0 +1,26 @@
# Generated by Django 5.0.8 on 2024-09-12 20:06

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ("database", "0060_merge_20240905_1828"),
    ]

    operations = [
        migrations.AlterField(
            model_name="chatmodeloptions",
            name="model_type",
            field=models.CharField(
                choices=[
                    ("openai", "Openai"),
                    ("offline", "Offline"),
                    ("anthropic", "Anthropic"),
                    ("google", "Google"),
                ],
                default="offline",
                max_length=200,
            ),
        ),
    ]
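The migration is applied with the usual `python manage.py migrate database`; afterwards the new choice should be registered on the field. A quick verification sketch for a Django shell:

```python
from khoj.database.models import ChatModelOptions

choices = ChatModelOptions._meta.get_field("model_type").choices
assert ("google", "Google") in choices
```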

@@ -87,6 +87,7 @@ class ChatModelOptions(BaseModel):
         OPENAI = "openai"
         OFFLINE = "offline"
         ANTHROPIC = "anthropic"
+        GOOGLE = "google"
 
     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)

@@ -0,0 +1,221 @@
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, Optional

from langchain.schema import ChatMessage

from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
    gemini_chat_completion_with_backoff,
    gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


def extract_questions_gemini(
    text,
    model: Optional[str] = "gemini-1.5-flash",
    conversation_log={},
    api_key=None,
    temperature=0,
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
):
    """
    Infer search queries to retrieve relevant notes to answer user query
    """
    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Conversation Log
    chat_history = "".join(
        [
            f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
            for chat in conversation_log.get("chat", [])[-4:]
            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
        ]
    )

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history,
        text=text,
    )

    messages = [ChatMessage(content=prompt, role="user")]

    model_kwargs = {"response_mime_type": "application/json"}

    response = gemini_completion_with_backoff(
        messages=messages,
        system_prompt=system_prompt,
        model_name=model,
        temperature=temperature,
        api_key=api_key,
        max_tokens=max_tokens,
        model_kwargs=model_kwargs,
    )

    # Extract, Clean Message from Gemini's Response
    try:
        response = response.strip()
        match = re.search(r"\{.*?\}", response)
        if match:
            response = match.group()
        response = json.loads(response)
        response = [q.strip() for q in response["queries"] if q.strip()]
        if not isinstance(response, list) or not response:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
        return response
    except:
        logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
        questions = [text]

    logger.debug(f"Extracted Questions by Gemini: {questions}")
    return questions


def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
    """
    Send message to model
    """
    system_prompt = None
    if len(messages) == 1:
        messages[0].role = "user"
    else:
        system_prompt = ""
        for message in messages.copy():
            if message.role == "system":
                system_prompt += message.content
                messages.remove(message)

    model_kwargs = {}
    if response_type == "json_object":
        model_kwargs["response_mime_type"] = "application/json"

    # Get Response from Gemini
    return gemini_completion_with_backoff(
        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
    )


def converse_gemini(
    references,
    user_query,
    online_results: Optional[Dict[str, Dict]] = None,
    conversation_log={},
    model: Optional[str] = "gemini-1.5-flash",
    api_key: Optional[str] = None,
    temperature: float = 0.2,
    completion_func=None,
    conversation_commands=[ConversationCommand.Default],
    max_prompt_size=None,
    tokenizer_name=None,
    location_data: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
):
    """
    Converse with user using Google's Gemini
    """
    # Initialize Variables
    current_date = datetime.now()
    compiled_references = "\n\n".join({f"# {item}" for item in references})

    conversation_primer = prompts.query_prompt.format(query=user_query)

    if agent and agent.personality:
        system_prompt = prompts.custom_personality.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.personality.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    if location_data:
        location = f"{location_data.city}, {location_data.region}, {location_data.country}"
        location_prompt = prompts.user_location.format(location=location)
        system_prompt = f"{system_prompt}\n{location_prompt}"

    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt = f"{system_prompt}\n{user_name_prompt}"

    # Get Conversation Primer appropriate to Conversation Type
    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
        completion_func(chat_response=prompts.no_notes_found.format())
        return iter([prompts.no_notes_found.format()])
    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
        completion_func(chat_response=prompts.no_online_results_found.format())
        return iter([prompts.no_online_results_found.format()])

    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
        conversation_primer = (
            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
        )
    if not is_none_or_empty(compiled_references):
        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        conversation_primer,
        conversation_log=conversation_log,
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
    )

    # Gemini expects the assistant role to be named "model"
    for message in messages:
        if message.role == "assistant":
            message.role = "model"

    # Fold any system messages into the system prompt
    for message in messages.copy():
        if message.role == "system":
            system_prompt += message.content
            messages.remove(message)

    truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
    logger.debug(f"Conversation Context for Gemini: {truncated_messages}")

    # Get Response from Google AI
    return gemini_chat_completion_with_backoff(
        messages=messages,
        compiled_references=references,
        online_results=online_results,
        model_name=model,
        temperature=temperature,
        api_key=api_key,
        system_prompt=system_prompt,
        completion_func=completion_func,
        max_prompt_size=max_prompt_size,
    )
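A hedged usage sketch for the new module (the API key and the sample output are placeholders; `conversation_log` is left empty for brevity):

```python
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini

queries = extract_questions_gemini(
    "What did I write about beekeeping last month?",
    model="gemini-1.5-flash",
    api_key="YOUR_GEMINI_API_KEY",  # placeholder
)
print(queries)  # e.g. ["beekeeping notes dt>='2024-08-01'"]
```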

@@ -0,0 +1,93 @@
import logging
from threading import Thread

import google.generativeai as genai
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
    wait_random_exponential,
)

from khoj.processor.conversation.utils import ThreadedGenerator

logger = logging.getLogger(__name__)

DEFAULT_MAX_TOKENS_GEMINI = 8192


@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def gemini_completion_with_backoff(
    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
) -> str:
    genai.configure(api_key=api_key)
    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
    model_kwargs = model_kwargs or dict()
    model_kwargs["temperature"] = temperature
    model_kwargs["max_output_tokens"] = max_tokens
    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)

    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
    # All messages up to the last are considered to be part of the chat history
    chat_session = model.start_chat(history=formatted_messages[0:-1])
    # The last message is considered to be the current prompt
    aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
    return aggregated_response.text


@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def gemini_chat_completion_with_backoff(
    messages,
    compiled_references,
    online_results,
    model_name,
    temperature,
    api_key,
    system_prompt,
    max_prompt_size=None,
    completion_func=None,
    model_kwargs=None,
):
    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
    t = Thread(
        target=gemini_llm_thread,
        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
    )
    t.start()
    return g


def gemini_llm_thread(
    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
):
    try:
        genai.configure(api_key=api_key)
        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
        model_kwargs = model_kwargs or dict()
        model_kwargs["temperature"] = temperature
        model_kwargs["max_output_tokens"] = max_tokens
        model_kwargs["stop_sequences"] = ["Notes:\n["]
        model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)

        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
        # All messages up to the last are considered to be part of the chat history
        chat_session = model.start_chat(history=formatted_messages[0:-1])
        # The last message is considered to be the current prompt
        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
            g.send(chunk.text)
    except Exception as e:
        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
    finally:
        # Always close the generator, even on error, so consumers don't hang
        g.close()
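Downstream callers consume the returned `ThreadedGenerator` as a plain iterator; a sketch of the streaming path under that assumption (placeholder key, minimal arguments):

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.google.utils import gemini_chat_completion_with_backoff

# The background gemini_llm_thread feeds chunks into the generator and
# guarantees close() in its finally block, so this loop always terminates.
generator = gemini_chat_completion_with_backoff(
    messages=[ChatMessage(content="Hello!", role="user")],
    compiled_references=[],
    online_results={},
    model_name="gemini-1.5-flash",
    temperature=0.2,
    api_key="YOUR_GEMINI_API_KEY",  # placeholder
    system_prompt="You are a helpful assistant.",
)
for chunk in generator:
    print(chunk, end="")
```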

@@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities:
 - You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
 - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
 - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
-  - inline math mode : `\\(` and `\\)`
-  - display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
+  - inline math mode : \\( and \\)
+  - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
 - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
 - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
 - Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
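The change drops the backticks around the delimiters so the model emits them literally instead of as code spans; the doubled backslashes are escapes in the Python prompt source, so the model sees bare `\(`, `\)`, `$$`, `\[`, `\]`. An illustrative compliant response:

```
The kinetic energy is \( E = \tfrac{1}{2}mv^2 \). In terms of momentum:
$$
E = \frac{p^2}{2m}
$$
```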

@@ -7,6 +7,7 @@ from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import aiohttp
+import requests
 from bs4 import BeautifulSoup
 from markdownify import markdownify

@@ -94,7 +95,7 @@ async def search_online(
     # Read, extract relevant info from the retrieved web pages
     if webpages:
-        webpage_links = [link for link, _, _ in webpages]
+        webpage_links = set([link for link, _, _ in webpages])
         logger.info(f"Reading web pages at: {list(webpage_links)}")
         if send_status_func:
             webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))

@@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
 from khoj.processor.conversation.anthropic.anthropic_chat import (
     extract_questions_anthropic,
 )
+from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
 from khoj.processor.conversation.offline.chat_model import extract_questions_offline
 from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions

@@ -419,6 +420,18 @@ async def extract_references_and_questions(
                 location_data=location_data,
                 user=user,
             )
+        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+            api_key = conversation_config.openai_config.api_key
+            chat_model = conversation_config.chat_model
+            inferred_queries = extract_questions_gemini(
+                defiltered_query,
+                model=chat_model,
+                api_key=api_key,
+                conversation_log=meta_log,
+                location_data=location_data,
+                max_tokens=conversation_config.max_prompt_size,
+                user=user,
+            )
 
     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):

@@ -76,6 +76,10 @@ from khoj.processor.conversation.anthropic.anthropic_chat import (
     anthropic_send_message_to_model,
     converse_anthropic,
 )
+from khoj.processor.conversation.google.gemini_chat import (
+    converse_gemini,
+    gemini_send_message_to_model,
+)
 from khoj.processor.conversation.offline.chat_model import (
     converse_offline,
     send_message_to_model_offline,

@@ -136,7 +140,7 @@ async def is_ready_to_chat(user: KhojUser):
         await ConversationAdapters.aget_default_conversation_config()
     )
 
-    if user_conversation_config and user_conversation_config.model_type == "offline":
+    if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
         chat_model = user_conversation_config.chat_model
         max_tokens = user_conversation_config.max_prompt_size
         if state.offline_chat_processor_config is None:

@@ -146,7 +150,14 @@ async def is_ready_to_chat(user: KhojUser):
     if (
         user_conversation_config
-        and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
+        and (
+            user_conversation_config.model_type
+            in [
+                ChatModelOptions.ModelType.OPENAI,
+                ChatModelOptions.ModelType.ANTHROPIC,
+                ChatModelOptions.ModelType.GOOGLE,
+            ]
+        )
         and user_conversation_config.openai_config
     ):
         return True

@@ -607,9 +618,10 @@ async def send_message_to_model_wrapper(
         else conversation_config.max_prompt_size
     )
     tokenizer = conversation_config.tokenizer
+    model_type = conversation_config.model_type
     vision_available = conversation_config.vision_enabled
 
-    if conversation_config.model_type == "offline":
+    if model_type == ChatModelOptions.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

@@ -633,7 +645,7 @@ async def send_message_to_model_wrapper(
             response_type=response_type,
         )
 
-    elif conversation_config.model_type == "openai":
+    elif model_type == ChatModelOptions.ModelType.OPENAI:
         openai_chat_config = conversation_config.openai_config
         api_key = openai_chat_config.api_key
         api_base_url = openai_chat_config.api_base_url

@@ -657,7 +669,7 @@ async def send_message_to_model_wrapper(
         )
         return openai_response
 
-    elif conversation_config.model_type == "anthropic":
+    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,

@@ -666,6 +678,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
             model_type=conversation_config.model_type,
         )

@@ -674,6 +687,21 @@ async def send_message_to_model_wrapper(
             api_key=api_key,
             model=chat_model,
         )
+    elif model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            tokenizer_name=tokenizer,
+            vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
+        )
+
+        return gemini_send_message_to_model(
+            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+        )
 
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
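This is where the commit's "enforce JSON response" behavior surfaces: `response_type="json_object"` is translated into Gemini's `response_mime_type="application/json"` inside `gemini_send_message_to_model`. A hedged call sketch (placeholder key):

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model

raw = gemini_send_message_to_model(
    messages=[ChatMessage(content='Reply with {"ok": true}', role="user")],
    api_key="YOUR_GEMINI_API_KEY",  # placeholder
    model="gemini-1.5-flash",
    response_type="json_object",  # forces application/json output
)
print(raw)
```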
@@ -692,7 +720,7 @@ def send_message_to_model_wrapper_sync(
     max_tokens = conversation_config.max_prompt_size
     vision_available = conversation_config.vision_enabled
 
-    if conversation_config.model_type == "offline":
+    if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

@@ -714,7 +742,7 @@ def send_message_to_model_wrapper_sync(
             response_type=response_type,
         )
 
-    elif conversation_config.model_type == "openai":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,

@@ -730,7 +758,7 @@ def send_message_to_model_wrapper_sync(
         )
         return openai_response
 
-    elif conversation_config.model_type == "anthropic":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,

@@ -746,6 +774,22 @@ def send_message_to_model_wrapper_sync(
             api_key=api_key,
             model=chat_model,
         )
+    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            vision_enabled=vision_available,
+        )
+
+        return gemini_send_message_to_model(
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+        )
 
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")

@@ -811,7 +855,7 @@ def generate_chat_response(
                 agent=agent,
             )
 
-        elif conversation_config.model_type == "openai":
+        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
             openai_chat_config = conversation_config.openai_config
             api_key = openai_chat_config.api_key
             chat_model = conversation_config.chat_model

@@ -834,7 +878,7 @@ def generate_chat_response(
                 vision_available=vision_available,
             )
 
-        elif conversation_config.model_type == "anthropic":
+        elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
             api_key = conversation_config.openai_config.api_key
             chat_response = converse_anthropic(
                 compiled_references,

@@ -851,6 +895,23 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
             )
+        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+            api_key = conversation_config.openai_config.api_key
+            chat_response = converse_gemini(
+                compiled_references,
+                q,
+                online_results,
+                meta_log,
+                model=conversation_config.chat_model,
+                api_key=api_key,
+                completion_func=partial_completion,
+                conversation_commands=conversation_commands,
+                max_prompt_size=conversation_config.max_prompt_size,
+                tokenizer_name=conversation_config.tokenizer,
+                location_data=location_data,
+                user_name=user_name,
+                agent=agent,
+            )
 
         metadata.update({"chat_model": conversation_config.chat_model})
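Note that every GOOGLE branch reads the Gemini API key from `conversation_config.openai_config.api_key`, i.e. the key is stored on the existing OpenAI processor config record rather than a new Google-specific model. A hedged setup sketch for a Django shell (the `OpenAIProcessorConversationConfig` name is assumed from how `openai_config` is used above; adjust to your deployment):

```python
from khoj.database.models import ChatModelOptions, OpenAIProcessorConversationConfig

# The Gemini key piggybacks on the OpenAI processor config record.
gemini_key_config = OpenAIProcessorConversationConfig.objects.create(api_key="YOUR_GEMINI_API_KEY")
ChatModelOptions.objects.create(
    chat_model="gemini-1.5-flash",
    model_type=ChatModelOptions.ModelType.GOOGLE,
    openai_config=gemini_key_config,
)
```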