diff --git a/pyproject.toml b/pyproject.toml
index 6e90d65f..b21fc1d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,8 @@ dependencies = [
     "cron-descriptor == 1.4.3",
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
-    "docx2txt == 0.8"
+    "docx2txt == 0.8",
+    "google-generativeai == 0.7.2"
 ]

 dynamic = ["version"]
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 48f1a6bd..9da94214 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -973,7 +973,7 @@ class ConversationAdapters:
         if conversation_config is None:
             conversation_config = ConversationAdapters.get_default_conversation_config()

-        if conversation_config.model_type == "offline":
+        if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
                 chat_model = conversation_config.chat_model
                 max_tokens = conversation_config.max_prompt_size
@@ -982,7 +982,12 @@ class ConversationAdapters:
             return conversation_config

         if (
-            conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
+            conversation_config.model_type
+            in [
+                ChatModelOptions.ModelType.ANTHROPIC,
+                ChatModelOptions.ModelType.OPENAI,
+                ChatModelOptions.ModelType.GOOGLE,
+            ]
         ) and conversation_config.openai_config:
             return conversation_config

diff --git a/src/khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py b/src/khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py
new file mode 100644
index 00000000..c7a602b0
--- /dev/null
+++ b/src/khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py
@@ -0,0 +1,26 @@
+# Generated by Django 5.0.8 on 2024-09-12 20:06
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0060_merge_20240905_1828"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="chatmodeloptions",
+            name="model_type",
+            field=models.CharField(
+                choices=[
+                    ("openai", "Openai"),
+                    ("offline", "Offline"),
+                    ("anthropic", "Anthropic"),
+                    ("google", "Google"),
+                ],
+                default="offline",
+                max_length=200,
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 5b995194..80769de8 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -87,6 +87,7 @@ class ChatModelOptions(BaseModel):
         OPENAI = "openai"
         OFFLINE = "offline"
         ANTHROPIC = "anthropic"
+        GOOGLE = "google"

     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
diff --git a/src/khoj/processor/conversation/google/__init__.py b/src/khoj/processor/conversation/google/__init__.py
new file mode 100644
index 00000000..e69de29b
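The new `GOOGLE` choice is a Django `TextChoices` member, so it compares equal to its underlying string and rows stored as `"google"` keep matching the enum-based checks introduced above. A minimal configuration sketch (e.g. from `python manage.py shell`); the `OpenAIProcessorConversationConfig` model name is an assumption not shown in this diff, inferred from the `conversation_config.openai_config` relation the adapter checks reuse for the Gemini API key:

```python
# Hypothetical setup sketch; the config model name and all values are placeholders.
from khoj.database.models import ChatModelOptions, OpenAIProcessorConversationConfig

# The Gemini API key rides on the reused openai_config relation.
gemini_api_config = OpenAIProcessorConversationConfig.objects.create(api_key="GEMINI_API_KEY")
ChatModelOptions.objects.create(
    chat_model="gemini-1.5-flash",
    model_type=ChatModelOptions.ModelType.GOOGLE,  # stored as the string "google"
    openai_config=gemini_api_config,
)

# TextChoices members are str subclasses, so enum and string comparisons agree:
assert ChatModelOptions.ModelType.GOOGLE == "google"
```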
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
new file mode 100644
index 00000000..64fa5d66
--- /dev/null
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -0,0 +1,220 @@
+import json
+import logging
+import re
+from datetime import datetime, timedelta
+from typing import Dict, Optional
+
+from langchain.schema import ChatMessage
+
+from khoj.database.models import Agent, KhojUser
+from khoj.processor.conversation import prompts
+from khoj.processor.conversation.google.utils import (
+    gemini_chat_completion_with_backoff,
+    gemini_completion_with_backoff,
+)
+from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.rawconfig import LocationData
+
+logger = logging.getLogger(__name__)
+
+
+def extract_questions_gemini(
+    text,
+    model: Optional[str] = "gemini-1.5-flash",
+    conversation_log={},
+    api_key=None,
+    temperature=0,
+    max_tokens=None,
+    location_data: LocationData = None,
+    user: KhojUser = None,
+):
+    """
+    Infer search queries to retrieve relevant notes to answer user query
+    """
+    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
+    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
+
+    # Extract Past User Message and Inferred Questions from Conversation Log
+    chat_history = "".join(
+        [
+            f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
+            for chat in conversation_log.get("chat", [])[-4:]
+            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
+        ]
+    )
+
+    # Get dates relative to today for prompt creation
+    today = datetime.today()
+    current_new_year = today.replace(month=1, day=1)
+    last_new_year = current_new_year.replace(year=today.year - 1)
+
+    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
+        current_date=today.strftime("%Y-%m-%d"),
+        day_of_week=today.strftime("%A"),
+        current_month=today.strftime("%Y-%m"),
+        last_new_year=last_new_year.strftime("%Y"),
+        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
+        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
+        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
+        location=location,
+        username=username,
+    )
+
+    prompt = prompts.extract_questions_anthropic_user_message.format(
+        chat_history=chat_history,
+        text=text,
+    )
+
+    messages = [ChatMessage(content=prompt, role="user")]
+
+    model_kwargs = {"response_mime_type": "application/json"}
+
+    response = gemini_completion_with_backoff(
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        temperature=temperature,
+        api_key=api_key,
+        max_tokens=max_tokens,
+        model_kwargs=model_kwargs,
+    )
+
+    # Extract, Clean Message from Gemini's Response
+    try:
+        response = response.strip()
+        match = re.search(r"\{.*?\}", response)
+        if match:
+            response = match.group()
+        response = json.loads(response)
+        response = [q.strip() for q in response["queries"] if q.strip()]
+        if not isinstance(response, list) or not response:
+            logger.error(f"Invalid response for constructing subqueries: {response}")
+            return [text]
+        return response
+    except Exception:
+        logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
+        questions = [text]
+    logger.debug(f"Extracted Questions by Gemini: {questions}")
+    return questions
+
+
+def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+    """
+    Send message to model
+    """
+    system_prompt = None
+    if len(messages) == 1:
+        messages[0].role = "user"
+    else:
+        system_prompt = ""
+        for message in messages.copy():
+            if message.role == "system":
+                system_prompt += message.content
+                messages.remove(message)
+
+    model_kwargs = {}
+    if response_type == "json_object":
+        model_kwargs["response_mime_type"] = "application/json"
+
+    # Get Response from Gemini
+    return gemini_completion_with_backoff(
+        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+    )
+
+
+def converse_gemini(
+    references,
+    user_query,
+    online_results: Optional[Dict[str, Dict]] = None,
+    conversation_log={},
+    model: Optional[str] = "gemini-1.5-flash",
+    api_key: Optional[str] = None,
+    temperature: float = 0.2,
+    completion_func=None,
+    conversation_commands=[ConversationCommand.Default],
+    max_prompt_size=None,
+    tokenizer_name=None,
+    location_data: LocationData = None,
+    user_name: str = None,
+    agent: Agent = None,
+):
+    """
+    Converse with user using Google's Gemini
+    """
+    # Initialize Variables
+    current_date = datetime.now()
+    compiled_references = "\n\n".join({f"# {item}" for item in references})
+
+    conversation_primer = prompts.query_prompt.format(query=user_query)
+
+    if agent and agent.personality:
+        system_prompt = prompts.custom_personality.format(
+            name=agent.name,
+            bio=agent.personality,
+            current_date=current_date.strftime("%Y-%m-%d"),
+            day_of_week=current_date.strftime("%A"),
+        )
+    else:
+        system_prompt = prompts.personality.format(
+            current_date=current_date.strftime("%Y-%m-%d"),
+            day_of_week=current_date.strftime("%A"),
+        )
+
+    if location_data:
+        location = f"{location_data.city}, {location_data.region}, {location_data.country}"
+        location_prompt = prompts.user_location.format(location=location)
+        system_prompt = f"{system_prompt}\n{location_prompt}"
+
+    if user_name:
+        user_name_prompt = prompts.user_name.format(name=user_name)
+        system_prompt = f"{system_prompt}\n{user_name_prompt}"
+
+    # Get Conversation Primer appropriate to Conversation Type
+    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
+        completion_func(chat_response=prompts.no_notes_found.format())
+        return iter([prompts.no_notes_found.format()])
+    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
+        completion_func(chat_response=prompts.no_online_results_found.format())
+        return iter([prompts.no_online_results_found.format()])
+
+    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
+        conversation_primer = (
+            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
+        )
+    if not is_none_or_empty(compiled_references):
+        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
+
+    # Setup Prompt with Primer or Conversation History
+    messages = generate_chatml_messages_with_context(
+        conversation_primer,
+        conversation_log=conversation_log,
+        model_name=model,
+        max_prompt_size=max_prompt_size,
+        tokenizer_name=tokenizer_name,
+    )
+
+    # Gemini expects the assistant role to be named "model"
+    for message in messages:
+        if message.role == "assistant":
+            message.role = "model"
+
+    for message in messages.copy():
+        if message.role == "system":
+            system_prompt += message.content
+            messages.remove(message)
+
+    truncated_messages = "\n".join([f"{message.content[:40]}..." for message in messages])
+    logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
+
+    # Get Response from Google AI
+    return gemini_chat_completion_with_backoff(
+        messages=messages,
+        compiled_references=references,
+        online_results=online_results,
+        model_name=model,
+        temperature=temperature,
+        api_key=api_key,
+        system_prompt=system_prompt,
+        completion_func=completion_func,
+        max_prompt_size=max_prompt_size,
+    )
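Note that `extract_questions_gemini` deliberately reuses the Anthropic prompt templates (`extract_questions_anthropic_system_prompt`, `extract_questions_anthropic_user_message`) rather than Gemini-specific ones, and forces JSON output via `response_mime_type`. A minimal usage sketch; the query and API key are placeholders:

```python
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini

queries = extract_questions_gemini(
    "What did I write about gardening last spring?",
    model="gemini-1.5-flash",
    api_key="GEMINI_API_KEY",  # placeholder, not a real key
)
# Malformed model output falls back to the raw user message, so this is
# always a non-empty list of search query strings.
print(queries)
```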
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
new file mode 100644
index 00000000..4ddf5e2c
--- /dev/null
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -0,0 +1,93 @@
+import logging
+from threading import Thread
+
+import google.generativeai as genai
+from tenacity import (
+    before_sleep_log,
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)
+
+from khoj.processor.conversation.utils import ThreadedGenerator
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_MAX_TOKENS_GEMINI = 8192
+
+
+@retry(
+    wait=wait_random_exponential(min=1, max=10),
+    stop=stop_after_attempt(2),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def gemini_completion_with_backoff(
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+) -> str:
+    genai.configure(api_key=api_key)
+    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
+    model_kwargs = model_kwargs or dict()
+    model_kwargs["temperature"] = temperature
+    model_kwargs["max_output_tokens"] = max_tokens
+    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+
+    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    # all messages up to the last are considered to be part of the chat history
+    chat_session = model.start_chat(history=formatted_messages[0:-1])
+    # the last message is considered to be the current prompt
+    aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+    return aggregated_response.text
+
+
+@retry(
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(2),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def gemini_chat_completion_with_backoff(
+    messages,
+    compiled_references,
+    online_results,
+    model_name,
+    temperature,
+    api_key,
+    system_prompt,
+    max_prompt_size=None,
+    completion_func=None,
+    model_kwargs=None,
+):
+    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
+    t = Thread(
+        target=gemini_llm_thread,
+        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+    )
+    t.start()
+    return g
+
+
+def gemini_llm_thread(
+    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+):
+    try:
+        genai.configure(api_key=api_key)
+        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
+        model_kwargs = model_kwargs or dict()
+        model_kwargs["temperature"] = temperature
+        model_kwargs["max_output_tokens"] = max_tokens
+        model_kwargs["stop_sequences"] = ["Notes:\n["]
+        model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+
+        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+        # all messages up to the last are considered to be part of the chat history
+        chat_session = model.start_chat(history=formatted_messages[0:-1])
+        # the last message is considered to be the current prompt
+        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+            g.send(chunk.text)
+    except Exception as e:
+        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
+    finally:
+        g.close()
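`gemini_chat_completion_with_backoff` returns immediately with a `ThreadedGenerator` that `gemini_llm_thread` feeds from a background thread, so callers can stream chunks as they arrive. A consumption sketch, assuming `ThreadedGenerator`'s usual iterator behaviour; the API key is a placeholder:

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.google.utils import gemini_chat_completion_with_backoff

response_gen = gemini_chat_completion_with_backoff(
    messages=[ChatMessage(content="Hello!", role="user")],
    compiled_references=[],
    online_results={},
    model_name="gemini-1.5-flash",
    temperature=0.2,
    api_key="GEMINI_API_KEY",  # placeholder
    system_prompt="You are a helpful assistant.",
)
# Chunks arrive as the background thread calls g.send(chunk.text)
for chunk in response_gen:
    print(chunk, end="")
```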
"assistant": + message.role = "model" + + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + + truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) + logger.debug(f"Conversation Context for Gemini: {truncated_messages}") + + # Get Response from Google AI + return gemini_chat_completion_with_backoff( + messages=messages, + compiled_references=references, + online_results=online_results, + model_name=model, + temperature=temperature, + api_key=api_key, + system_prompt=system_prompt, + completion_func=completion_func, + max_prompt_size=max_prompt_size, + ) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py new file mode 100644 index 00000000..4ddf5e2c --- /dev/null +++ b/src/khoj/processor/conversation/google/utils.py @@ -0,0 +1,93 @@ +import logging +from threading import Thread + +import google.generativeai as genai +from tenacity import ( + before_sleep_log, + retry, + stop_after_attempt, + wait_exponential, + wait_random_exponential, +) + +from khoj.processor.conversation.utils import ThreadedGenerator + +logger = logging.getLogger(__name__) + + +DEFAULT_MAX_TOKENS_GEMINI = 8192 + + +@retry( + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(2), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def gemini_completion_with_backoff( + messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None +) -> str: + genai.configure(api_key=api_key) + max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI + model_kwargs = model_kwargs or dict() + model_kwargs["temperature"] = temperature + model_kwargs["max_output_tokens"] = max_tokens + model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt) + + formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages] + # all messages up to the last are considered to be part of the chat history + chat_session = model.start_chat(history=formatted_messages[0:-1]) + # the last message is considered to be the current prompt + aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0]) + return aggregated_response.text + + +@retry( + wait=wait_exponential(multiplier=1, min=4, max=10), + stop=stop_after_attempt(2), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def gemini_chat_completion_with_backoff( + messages, + compiled_references, + online_results, + model_name, + temperature, + api_key, + system_prompt, + max_prompt_size=None, + completion_func=None, + model_kwargs=None, +): + g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) + t = Thread( + target=gemini_llm_thread, + args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs), + ) + t.start() + return g + + +def gemini_llm_thread( + g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None +): + try: + genai.configure(api_key=api_key) + max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI + model_kwargs = model_kwargs or dict() + model_kwargs["temperature"] = temperature + model_kwargs["max_output_tokens"] = max_tokens + model_kwargs["stop_sequences"] = ["Notes:\n["] + model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt) + + 
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages] + # all messages up to the last are considered to be part of the chat history + chat_session = model.start_chat(history=formatted_messages[0:-1]) + # the last message is considered to be the current prompt + for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True): + g.send(chunk.text) + except Exception as e: + logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True) + finally: + g.close() diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 755bb6bd..24bf8fdd 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -13,8 +13,8 @@ You were created by Khoj Inc. with the following capabilities: - You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes. - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following - - inline math mode : `\\(` and `\\)` - - display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]` + - inline math mode : \\( and \\) + - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\] - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay". - Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim. 
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index f00c362d..31947c41 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -31,6 +31,7 @@ from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOp
 from khoj.processor.conversation.anthropic.anthropic_chat import (
     extract_questions_anthropic,
 )
+from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini
 from khoj.processor.conversation.offline.chat_model import extract_questions_offline
 from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
@@ -419,6 +420,18 @@ async def extract_references_and_questions(
             location_data=location_data,
             user=user,
         )
+    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        chat_model = conversation_config.chat_model
+        inferred_queries = extract_questions_gemini(
+            defiltered_query,
+            model=chat_model,
+            api_key=api_key,
+            conversation_log=meta_log,
+            location_data=location_data,
+            max_tokens=conversation_config.max_prompt_size,
+            user=user,
+        )

     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
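The new branch mirrors the Anthropic path, reusing `openai_config` for the Gemini key. Downstream, structured output is requested the same way `extract_questions_gemini` does; a sketch of calling the helper directly in JSON mode, with a placeholder key:

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model

raw_json = gemini_send_message_to_model(
    messages=[ChatMessage(content='Reply with {"answer": "yes"}', role="user")],
    api_key="GEMINI_API_KEY",  # placeholder
    model="gemini-1.5-flash",
    response_type="json_object",  # maps to response_mime_type="application/json"
)
```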
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e8c9e9ce..5687937a 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -76,6 +76,10 @@ from khoj.processor.conversation.anthropic.anthropic_chat import (
     anthropic_send_message_to_model,
     converse_anthropic,
 )
+from khoj.processor.conversation.google.gemini_chat import (
+    converse_gemini,
+    gemini_send_message_to_model,
+)
 from khoj.processor.conversation.offline.chat_model import (
     converse_offline,
     send_message_to_model_offline,
@@ -136,7 +140,7 @@ async def is_ready_to_chat(user: KhojUser):
         await ConversationAdapters.aget_default_conversation_config()
     )

-    if user_conversation_config and user_conversation_config.model_type == "offline":
+    if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
         chat_model = user_conversation_config.chat_model
         max_tokens = user_conversation_config.max_prompt_size
         if state.offline_chat_processor_config is None:
@@ -146,7 +150,14 @@ async def is_ready_to_chat(user: KhojUser):

     if (
         user_conversation_config
-        and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
+        and (
+            user_conversation_config.model_type
+            in [
+                ChatModelOptions.ModelType.OPENAI,
+                ChatModelOptions.ModelType.ANTHROPIC,
+                ChatModelOptions.ModelType.GOOGLE,
+            ]
+        )
         and user_conversation_config.openai_config
     ):
         return True
@@ -607,9 +618,10 @@ async def send_message_to_model_wrapper(
         else conversation_config.max_prompt_size
     )
     tokenizer = conversation_config.tokenizer
+    model_type = conversation_config.model_type
     vision_available = conversation_config.vision_enabled

-    if conversation_config.model_type == "offline":
+    if model_type == ChatModelOptions.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
@@ -633,7 +645,7 @@ async def send_message_to_model_wrapper(
             response_type=response_type,
         )

-    elif conversation_config.model_type == "openai":
+    elif model_type == ChatModelOptions.ModelType.OPENAI:
         openai_chat_config = conversation_config.openai_config
         api_key = openai_chat_config.api_key
         api_base_url = openai_chat_config.api_base_url
@@ -657,7 +669,7 @@ async def send_message_to_model_wrapper(
         )
         return openai_response

-    elif conversation_config.model_type == "anthropic":
+    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
@@ -666,6 +678,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
             model_type=conversation_config.model_type,
         )
@@ -674,6 +687,21 @@ async def send_message_to_model_wrapper(
             api_key=api_key,
             model=chat_model,
         )
+    elif model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            tokenizer_name=tokenizer,
+            vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
+        )
+
+        return gemini_send_message_to_model(
+            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+        )
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -692,7 +720,7 @@ def send_message_to_model_wrapper_sync(
     max_tokens = conversation_config.max_prompt_size
     vision_available = conversation_config.vision_enabled

-    if conversation_config.model_type == "offline":
+    if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
@@ -714,7 +742,7 @@ def send_message_to_model_wrapper_sync(
         response_type=response_type,
     )

-    elif conversation_config.model_type == "openai":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
@@ -730,7 +758,7 @@ def send_message_to_model_wrapper_sync(

         return openai_response

-    elif conversation_config.model_type == "anthropic":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
@@ -746,6 +774,22 @@ def send_message_to_model_wrapper_sync(
             api_key=api_key,
             model=chat_model,
         )
+
+    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            vision_enabled=vision_available,
+        )
+
+        return gemini_send_message_to_model(
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+        )
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -811,7 +855,7 @@ def generate_chat_response(
             agent=agent,
         )

-    elif conversation_config.model_type == "openai":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
         openai_chat_config = conversation_config.openai_config
         api_key = openai_chat_config.api_key
         chat_model = conversation_config.chat_model
@@ -834,7 +878,7 @@ def generate_chat_response(
             vision_available=vision_available,
         )

-    elif conversation_config.model_type == "anthropic":
+    elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
         chat_response = converse_anthropic(
             compiled_references,
@@ -851,6 +895,23 @@ def generate_chat_response(
             user_name=user_name,
             agent=agent,
         )
+    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        chat_response = converse_gemini(
+            compiled_references,
+            q,
+            online_results,
+            meta_log,
+            model=conversation_config.chat_model,
+            api_key=api_key,
+            completion_func=partial_completion,
+            conversation_commands=conversation_commands,
+            max_prompt_size=conversation_config.max_prompt_size,
+            tokenizer_name=conversation_config.tokenizer,
+            location_data=location_data,
+            user_name=user_name,
+            agent=agent,
+        )

     metadata.update({"chat_model": conversation_config.chat_model})
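Taken together, `generate_chat_response` streams Gemini replies through the same `ThreadedGenerator` plumbing as the other providers. An end-to-end sketch of the new path with placeholder values; `completion_func` is omitted here since the default-command path in `converse_gemini` never invokes it:

```python
from khoj.processor.conversation.google.gemini_chat import converse_gemini

chat_response = converse_gemini(
    references=["Planted tomatoes in the garden in April."],
    user_query="When did I plant tomatoes?",
    conversation_log={},
    model="gemini-1.5-flash",
    api_key="GEMINI_API_KEY",  # placeholder
)
# The returned ThreadedGenerator yields streamed text chunks
for chunk in chat_response:
    print(chunk, end="")
```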