From 01cdc54ad048d19234a0d838908c191dce17cb53 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Sun, 26 May 2024 22:50:34 +0530 Subject: [PATCH] Add support for Anthropic models (#760) * Add support for chatting with Anthropic's suite of models - Had to use a custom class because there was enough nuance with how the anthropic SDK works that it would be better to simply separate out the logic. The extract questions flow needed modification of the system prompt in order to work as intended with the haiku model --- pyproject.toml | 1 + src/khoj/database/adapters/__init__.py | 4 +- .../0043_alter_chatmodeloptions_model_type.py | 21 ++ src/khoj/database/models/__init__.py | 1 + .../conversation/anthropic/__init__.py | 0 .../conversation/anthropic/anthropic_chat.py | 204 ++++++++++++++++++ .../processor/conversation/anthropic/utils.py | 116 ++++++++++ src/khoj/processor/conversation/prompts.py | 39 ++++ src/khoj/routers/api.py | 16 +- src/khoj/routers/helpers.py | 57 ++++- 10 files changed, 454 insertions(+), 5 deletions(-) create mode 100644 src/khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py create mode 100644 src/khoj/processor/conversation/anthropic/__init__.py create mode 100644 src/khoj/processor/conversation/anthropic/anthropic_chat.py create mode 100644 src/khoj/processor/conversation/anthropic/utils.py diff --git a/pyproject.toml b/pyproject.toml index c7209b58..0f04d8a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ dependencies = [ "pytz ~= 2024.1", "cron-descriptor == 1.4.3", "django_apscheduler == 0.6.2", + "anthropic == 0.26.1", ] dynamic = ["version"] diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 33107d2d..bc964c3f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -833,7 +833,9 @@ class ConversationAdapters: return conversation_config - if conversation_config.model_type == "openai" and conversation_config.openai_config: + if ( + conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic" + ) and conversation_config.openai_config: return conversation_config else: diff --git a/src/khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py b/src/khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py new file mode 100644 index 00000000..246c78be --- /dev/null +++ b/src/khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.10 on 2024-05-26 12:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0042_serverchatsettings"), + ] + + operations = [ + migrations.AlterField( + model_name="chatmodeloptions", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("offline", "Offline"), ("anthropic", "Anthropic")], + default="offline", + max_length=200, + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 64741018..def10c0a 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -84,6 +84,7 @@ class ChatModelOptions(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" OFFLINE = "offline" + ANTHROPIC = "anthropic" max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) diff --git a/src/khoj/processor/conversation/anthropic/__init__.py b/src/khoj/processor/conversation/anthropic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py new file mode 100644 index 00000000..28b1e0f0 --- /dev/null +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -0,0 +1,204 @@ +import json +import logging +import re +from datetime import datetime, timedelta +from typing import Dict, Optional + +from langchain.schema import ChatMessage + +from khoj.database.models import Agent +from khoj.processor.conversation import prompts +from khoj.processor.conversation.anthropic.utils import ( + anthropic_chat_completion_with_backoff, + anthropic_completion_with_backoff, +) +from khoj.processor.conversation.utils import generate_chatml_messages_with_context +from khoj.utils.helpers import ConversationCommand, is_none_or_empty +from khoj.utils.rawconfig import LocationData + +logger = logging.getLogger(__name__) + + +def extract_questions_anthropic( + text, + model: Optional[str] = "claude-instant-1.2", + conversation_log={}, + api_key=None, + temperature=0, + max_tokens=100, + location_data: LocationData = None, +): + """ + Infer search queries to retrieve relevant notes to answer user query + """ + # Extract Past User Message and Inferred Questions from Conversation Log + location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + + # Extract Past User Message and Inferred Questions from Conversation Log + chat_history = "".join( + [ + f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n' + for chat in conversation_log.get("chat", [])[-4:] + if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type") + ] + ) + + # Get dates relative to today for prompt creation + today = datetime.today() + current_new_year = today.replace(month=1, day=1) + last_new_year = current_new_year.replace(year=today.year - 1) + + system_prompt = prompts.extract_questions_anthropic_system_prompt.format( + current_date=today.strftime("%Y-%m-%d"), + day_of_week=today.strftime("%A"), + last_new_year=last_new_year.strftime("%Y"), + last_new_year_date=last_new_year.strftime("%Y-%m-%d"), + current_new_year_date=current_new_year.strftime("%Y-%m-%d"), + yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), + location=location, + ) + + prompt = prompts.extract_questions_anthropic_user_message.format( + chat_history=chat_history, + text=text, + ) + + messages = [ChatMessage(content=prompt, role="user")] + + response = anthropic_completion_with_backoff( + messages=messages, + system_prompt=system_prompt, + model_name=model, + temperature=temperature, + api_key=api_key, + max_tokens=max_tokens, + ) + + # Extract, Clean Message from Claude's Response + try: + response = response.strip() + match = re.search(r"\{.*?\}", response) + if match: + response = match.group() + response = json.loads(response) + response = [q.strip() for q in response["queries"] if q.strip()] + if not isinstance(response, list) or not response: + logger.error(f"Invalid response for constructing subqueries: {response}") + return [text] + return response + except: + logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}") + questions = [text] + logger.debug(f"Extracted Questions by Claude: {questions}") + return questions + + +def anthropic_send_message_to_model(messages, api_key, model): + """ + Send message to model + """ + # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter + system_prompt = None + + if len(messages) == 1: + messages[0].role = "user" + else: + system_prompt = "" + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + + # Get Response from GPT. Don't use response_type because Anthropic doesn't support it. + return anthropic_completion_with_backoff( + messages=messages, + system_prompt=system_prompt, + model_name=model, + api_key=api_key, + ) + + +def converse_anthropic( + references, + user_query, + online_results: Optional[Dict[str, Dict]] = None, + conversation_log={}, + model: Optional[str] = "claude-instant-1.2", + api_key: Optional[str] = None, + completion_func=None, + conversation_commands=[ConversationCommand.Default], + max_prompt_size=None, + tokenizer_name=None, + location_data: LocationData = None, + user_name: str = None, + agent: Agent = None, +): + """ + Converse with user using Anthropic's Claude + """ + # Initialize Variables + current_date = datetime.now().strftime("%Y-%m-%d") + compiled_references = "\n\n".join({f"# {item}" for item in references}) + + conversation_primer = prompts.query_prompt.format(query=user_query) + + if agent and agent.personality: + system_prompt = prompts.custom_personality.format( + name=agent.name, bio=agent.personality, current_date=current_date + ) + else: + system_prompt = prompts.personality.format(current_date=current_date) + + if location_data: + location = f"{location_data.city}, {location_data.region}, {location_data.country}" + location_prompt = prompts.user_location.format(location=location) + system_prompt = f"{system_prompt}\n{location_prompt}" + + if user_name: + user_name_prompt = prompts.user_name.format(name=user_name) + system_prompt = f"{system_prompt}\n{user_name_prompt}" + + # Get Conversation Primer appropriate to Conversation Type + if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references): + completion_func(chat_response=prompts.no_notes_found.format()) + return iter([prompts.no_notes_found.format()]) + elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results): + completion_func(chat_response=prompts.no_online_results_found.format()) + return iter([prompts.no_online_results_found.format()]) + + if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands: + conversation_primer = ( + f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}" + ) + if not is_none_or_empty(compiled_references): + conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}" + + # Setup Prompt with Primer or Conversation History + messages = generate_chatml_messages_with_context( + conversation_primer, + conversation_log=conversation_log, + model_name=model, + max_prompt_size=max_prompt_size, + tokenizer_name=tokenizer_name, + ) + + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + + truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) + logger.debug(f"Conversation Context for Claude: {truncated_messages}") + + # Get Response from Claude + return anthropic_chat_completion_with_backoff( + messages=messages, + compiled_references=references, + online_results=online_results, + model_name=model, + temperature=0, + api_key=api_key, + system_prompt=system_prompt, + completion_func=completion_func, + max_prompt_size=max_prompt_size, + ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py new file mode 100644 index 00000000..d6c8f4f2 --- /dev/null +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -0,0 +1,116 @@ +import logging +from threading import Thread +from typing import Dict, List + +import anthropic +from tenacity import ( + before_sleep_log, + retry, + stop_after_attempt, + wait_exponential, + wait_random_exponential, +) + +from khoj.processor.conversation.utils import ThreadedGenerator + +logger = logging.getLogger(__name__) + +anthropic_clients: Dict[str, anthropic.Anthropic] = {} + + +DEFAULT_MAX_TOKENS_ANTHROPIC = 3000 + + +@retry( + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(2), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def anthropic_completion_with_backoff( + messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None +) -> str: + if api_key not in anthropic_clients: + client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) + anthropic_clients[api_key] = client + else: + client = anthropic_clients[api_key] + + formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + + aggregated_response = "" + max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC + + model_kwargs = model_kwargs or dict() + if system_prompt: + model_kwargs["system"] = system_prompt + + with client.messages.stream( + messages=formatted_messages, + model=model_name, # type: ignore + temperature=temperature, + timeout=20, + max_tokens=max_tokens, + **(model_kwargs), + ) as stream: + for text in stream.text_stream: + aggregated_response += text + + return aggregated_response + + +@retry( + wait=wait_exponential(multiplier=1, min=4, max=10), + stop=stop_after_attempt(2), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def anthropic_chat_completion_with_backoff( + messages, + compiled_references, + online_results, + model_name, + temperature, + api_key, + system_prompt, + max_prompt_size=None, + completion_func=None, + model_kwargs=None, +): + g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) + t = Thread( + target=anthropic_llm_thread, + args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs), + ) + t.start() + return g + + +def anthropic_llm_thread( + g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None +): + if api_key not in anthropic_clients: + client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) + anthropic_clients[api_key] = client + else: + client: anthropic.Anthropic = anthropic_clients[api_key] + + formatted_messages: List[anthropic.types.MessageParam] = [ + anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages + ] + + max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC + + with client.messages.stream( + messages=formatted_messages, + model=model_name, # type: ignore + temperature=temperature, + system=system_prompt, + timeout=20, + max_tokens=max_prompt_size, + **(model_kwargs or dict()), + ) as stream: + for text in stream.text_stream: + g.send(text) + + g.close() diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index bccd7719..d0f5356d 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -261,6 +261,45 @@ Khoj: """.strip() ) +extract_questions_anthropic_system_prompt = PromptTemplate.from_template( + """ +You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests. Construct search queries to retrieve relevant information to answer the user's question. +- You will be provided past questions(Q) and answers(A) for context. +- Add as much context from the previous questions and answers as required into your search queries. +- Break messages into multiple search queries when required to retrieve the relevant information. +- Add date filters to your search queries from questions and answers when required to retrieve the relevant information. + +What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else. + +Current Date: {day_of_week}, {current_date} +User's Location: {location} + +Here are some examples of how you can construct search queries to answer the user's question: + +User: How was my trip to Cambodia? +Assistant: {{"queries": ["How was my trip to Cambodia?"]}} + +User: What national parks did I go to last year? +Assistant: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}} + +User: How can you help me? +Assistant: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}} + +User: Who all did I meet here yesterday? +Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}} +""".strip() +) + +extract_questions_anthropic_user_message = PromptTemplate.from_template( + """ +Here's our most recent chat history: +{chat_history} + +User: {text} +Assistant: +""".strip() +) + system_prompt_extract_relevant_information = """As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. The text provided is directly from within the web page. The report you create should be multiple paragraphs, and it should represent the content of the website. Tell the user exactly what the website says in response to their query, while adhering to these guidelines: 1. Answer the user's query as specifically as possible. Include many supporting details from the website. diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4cdb53e4..ff1217f9 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -27,6 +27,9 @@ from khoj.database.adapters import ( get_user_search_model_or_default, ) from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions +from khoj.processor.conversation.anthropic.anthropic_chat import ( + extract_questions_anthropic, +) from khoj.processor.conversation.offline.chat_model import extract_questions_offline from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions @@ -37,7 +40,6 @@ from khoj.routers.helpers import ( ConversationCommandRateLimiter, acreate_title_from_query, schedule_automation, - scheduled_chat, update_telemetry_state, ) from khoj.search_filter.date_filter import DateFilter @@ -340,6 +342,18 @@ async def extract_references_and_questions( api_key=api_key, conversation_log=meta_log, location_data=location_data, + max_tokens=conversation_config.max_prompt_size, + ) + elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: + api_key = conversation_config.openai_config.api_key + chat_model = conversation_config.chat_model + inferred_queries = extract_questions_anthropic( + defiltered_query, + model=chat_model, + api_key=api_key, + conversation_log=meta_log, + location_data=location_data, + max_tokens=conversation_config.max_prompt_size, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5c1dc25d..51106408 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -53,6 +53,10 @@ from khoj.database.models import ( UserRequests, ) from khoj.processor.conversation import prompts +from khoj.processor.conversation.anthropic.anthropic_chat import ( + anthropic_send_message_to_model, + converse_anthropic, +) from khoj.processor.conversation.offline.chat_model import ( converse_offline, send_message_to_model_offline, @@ -113,7 +117,7 @@ async def is_ready_to_chat(user: KhojUser): if ( user_conversation_config - and user_conversation_config.model_type == "openai" + and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic") and user_conversation_config.openai_config ): return True @@ -508,6 +512,21 @@ async def send_message_to_model_wrapper( ) return openai_response + elif conversation_config.model_type == "anthropic": + api_key = conversation_config.openai_config.api_key + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + max_prompt_size=max_tokens, + tokenizer_name=tokenizer, + ) + + return anthropic_send_message_to_model( + messages=truncated_messages, + api_key=api_key, + model=chat_model, + ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -542,8 +561,7 @@ def send_message_to_model_wrapper_sync( ) elif conversation_config.model_type == "openai": - openai_chat_config = ConversationAdapters.get_openai_conversation_config() - api_key = openai_chat_config.api_key + api_key = conversation_config.openai_config.api_key truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, model_name=chat_model ) @@ -553,6 +571,21 @@ def send_message_to_model_wrapper_sync( ) return openai_response + + elif conversation_config.model_type == "anthropic": + api_key = conversation_config.openai_config.api_key + truncated_messages = generate_chatml_messages_with_context( + user_message=message, + system_message=system_message, + model_name=chat_model, + max_prompt_size=max_tokens, + ) + + return anthropic_send_message_to_model( + messages=truncated_messages, + api_key=api_key, + model=chat_model, + ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -631,6 +664,24 @@ def generate_chat_response( agent=agent, ) + elif conversation_config.model_type == "anthropic": + api_key = conversation_config.openai_config.api_key + chat_response = converse_anthropic( + compiled_references, + q, + online_results, + meta_log, + model=conversation_config.chat_model, + api_key=api_key, + completion_func=partial_completion, + conversation_commands=conversation_commands, + max_prompt_size=conversation_config.max_prompt_size, + tokenizer_name=conversation_config.tokenizer, + location_data=location_data, + user_name=user_name, + agent=agent, + ) + metadata.update({"chat_model": conversation_config.chat_model}) except Exception as e: