Add support for Anthropic models (#760)

* Add support for chatting with Anthropic's suite of models

- Had to use a custom class because the Anthropic SDK differs enough from OpenAI's that it was cleaner to separate out the logic. The extract questions flow needed a modified system prompt to work as intended with the Haiku model
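For context, a minimal sketch (not part of this commit) of the Anthropic SDK call shape the new module wraps; the model name and API key are placeholders. Unlike the OpenAI chat API, the system prompt is a top-level parameter rather than a message, the first message must have the "user" role, and max_tokens is required:

import anthropic

client = anthropic.Anthropic(api_key="sk-ant-...")  # placeholder key

# Stream a completion; the system prompt goes in the top-level `system` parameter
with client.messages.stream(
    model="claude-3-haiku-20240307",  # placeholder model name
    system="You are Khoj, a helpful assistant.",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=300,
) as stream:
    for text in stream.text_stream:
        print(text, end="")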
sabaimran 2024-05-26 22:50:34 +05:30 committed by GitHub
parent e2922968d6
commit 01cdc54ad0
10 changed files with 454 additions and 5 deletions

pyproject.toml

@@ -85,6 +85,7 @@ dependencies = [
"pytz ~= 2024.1",
"cron-descriptor == 1.4.3",
"django_apscheduler == 0.6.2",
"anthropic == 0.26.1",
]
dynamic = ["version"]

src/khoj/database/adapters/__init__.py

@@ -833,7 +833,9 @@ class ConversationAdapters:
return conversation_config
-        if conversation_config.model_type == "openai" and conversation_config.openai_config:
+        if (
+            conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
+        ) and conversation_config.openai_config:
return conversation_config
else:
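Note: rather than adding a separate provider config model, Anthropic chat models reuse the existing openai_config relation to store their API key (see api_key = conversation_config.openai_config.api_key in the router changes below), so the validity check now accepts either model type when an OpenAI-style config is attached.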

src/khoj/database/migrations/0043_alter_chatmodeloptions_model_type.py

@@ -0,0 +1,21 @@
# Generated by Django 4.2.10 on 2024-05-26 12:35
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0042_serverchatsettings"),
]
operations = [
migrations.AlterField(
model_name="chatmodeloptions",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("offline", "Offline"), ("anthropic", "Anthropic")],
default="offline",
max_length=200,
),
),
]
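Since this AlterField only widens the choices on model_type, and Django enforces choices at the validation layer rather than in the database schema, the migration should be a near no-op at the database level on most backends.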

src/khoj/database/models/__init__.py

@@ -84,6 +84,7 @@ class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
ANTHROPIC = "anthropic"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)

src/khoj/processor/conversation/anthropic/anthropic_chat.py

@@ -0,0 +1,204 @@
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
def extract_questions_anthropic(
text,
model: Optional[str] = "claude-instant-1.2",
conversation_log={},
api_key=None,
temperature=0,
max_tokens=100,
location_data: LocationData = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"

# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
]
)
# Get dates relative to today for prompt creation
today = datetime.today()
current_new_year = today.replace(month=1, day=1)
last_new_year = current_new_year.replace(year=today.year - 1)
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
)
prompt = prompts.extract_questions_anthropic_user_message.format(
chat_history=chat_history,
text=text,
)
messages = [ChatMessage(content=prompt, role="user")]
response = anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
max_tokens=max_tokens,
)
# Extract, Clean Message from Claude's Response
try:
    response = response.strip()
    # Claude may wrap the JSON in prose; isolate the first JSON object
    match = re.search(r"\{.*?\}", response)
    if match:
        response = match.group()
    response = json.loads(response)
    questions = [q.strip() for q in response["queries"] if q.strip()]
    if not questions:
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [text]
except Exception:
    logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
    questions = [text]
logger.debug(f"Extracted Questions by Claude: {questions}")
return questions
def anthropic_send_message_to_model(messages, api_key, model):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None
if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
# Get response from Claude. Don't pass response_format since the Anthropic API doesn't support it.
return anthropic_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
)
def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-instant-1.2",
api_key: Optional[str] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
):
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name, bio=agent.personality, current_date=current_date
)
else:
system_prompt = prompts.personality.format(current_date=current_date)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
conversation_log=conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
)
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")
# Get Response from Claude
return anthropic_chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=0,
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
)
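A usage sketch of the new question-extraction entry point, assuming a valid Anthropic API key; the query, model name, and key below are illustrative:

from khoj.processor.conversation.anthropic.anthropic_chat import extract_questions_anthropic

# Illustrative values; falls back to [text] if Claude returns malformed JSON
queries = extract_questions_anthropic(
    "What did I write about my trip to Cambodia?",
    model="claude-3-haiku-20240307",  # placeholder model name
    api_key="sk-ant-...",             # placeholder key
)
print(queries)  # e.g. ['How was my trip to Cambodia?']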

src/khoj/processor/conversation/anthropic/utils.py

@@ -0,0 +1,116 @@
import logging
from threading import Thread
from typing import Dict, List
import anthropic
from tenacity import (
before_sleep_log,
retry,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
logger = logging.getLogger(__name__)
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
@retry(
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
aggregated_response = ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
timeout=20,
max_tokens=max_tokens,
**(model_kwargs),
) as stream:
for text in stream.text_stream:
aggregated_response += text
return aggregated_response
@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def anthropic_chat_completion_with_backoff(
messages,
compiled_references,
online_results,
model_name,
temperature,
api_key,
system_prompt,
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
)
t.start()
return g
def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
):
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
anthropic_clients[api_key] = client
else:
client: anthropic.Anthropic = anthropic_clients[api_key]
formatted_messages: List[anthropic.types.MessageParam] = [
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC
try:
    with client.messages.stream(
        messages=formatted_messages,
        model=model_name,  # type: ignore
        temperature=temperature,
        system=system_prompt,
        timeout=20,
        max_tokens=max_prompt_size,
        **(model_kwargs or dict()),
    ) as stream:
        for text in stream.text_stream:
            g.send(text)
finally:
    # Close the generator even if streaming fails, so the response stream terminates
    g.close()
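Design note: anthropic_chat_completion_with_backoff returns the ThreadedGenerator immediately while a background thread pushes streamed tokens into it, so the HTTP layer can start relaying the response before the completion finishes. Note also that max_prompt_size doubles as the Anthropic max_tokens output cap here, falling back to DEFAULT_MAX_TOKENS_ANTHROPIC (3000) when unset.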

src/khoj/processor/conversation/prompts.py

@@ -261,6 +261,45 @@ Khoj:
""".strip()
)
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests. Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided past questions (Q) and answers (A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
What searches will you perform to answer the user's question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
Here are some examples of how you can construct search queries to answer the user's question:
User: How was my trip to Cambodia?
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
User: What national parks did I go to last year?
Assistant: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
User: How can you help me?
Assistant: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
User: Who all did I meet here yesterday?
Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
""".strip()
)
extract_questions_anthropic_user_message = PromptTemplate.from_template(
"""
Here's our most recent chat history:
{chat_history}
User: {text}
Assistant:
""".strip()
)
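Note that the system prompt pins Claude to a bare JSON object of the form {"queries": [...]}; the re.search(r"\{.*?\}", ...) fallback in extract_questions_anthropic then isolates that object from any surrounding prose.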
system_prompt_extract_relevant_information = """As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. The text provided is directly from within the web page. The report you create should be multiple paragraphs, and it should represent the content of the website. Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
1. Answer the user's query as specifically as possible. Include many supporting details from the website.

src/khoj/routers/api.py

@@ -27,6 +27,9 @@ from khoj.database.adapters import (
get_user_search_model_or_default,
)
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
)
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
@@ -37,7 +40,6 @@ from khoj.routers.helpers import (
ConversationCommandRateLimiter,
acreate_title_from_query,
schedule_automation,
scheduled_chat,
update_telemetry_state,
)
from khoj.search_filter.date_filter import DateFilter
@@ -340,6 +342,18 @@ async def extract_references_and_questions(
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
chat_model = conversation_config.chat_model
inferred_queries = extract_questions_anthropic(
defiltered_query,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
)
# Collate search results as context for GPT

src/khoj/routers/helpers.py

@@ -53,6 +53,10 @@ from khoj.database.models import (
UserRequests,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
anthropic_send_message_to_model,
converse_anthropic,
)
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
send_message_to_model_offline,
@@ -113,7 +117,7 @@ async def is_ready_to_chat(user: KhojUser):
if (
user_conversation_config
-        and user_conversation_config.model_type == "openai"
+        and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
and user_conversation_config.openai_config
):
return True
@@ -508,6 +512,21 @@
)
return openai_response
elif conversation_config.model_type == "anthropic":
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -542,8 +561,7 @@
)
elif conversation_config.model_type == "openai":
-        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
-        api_key = openai_chat_config.api_key
+        api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model
)
@@ -553,6 +571,21 @@
)
return openai_response
elif conversation_config.model_type == "anthropic":
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -631,6 +664,24 @@
agent=agent,
)
elif conversation_config.model_type == "anthropic":
api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic(
compiled_references,
q,
online_results,
meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
)
metadata.update({"chat_model": conversation_config.chat_model})
except Exception as e: