mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add support for Anthropic models (#760)
* Add support for chatting with Anthropic's suite of models - Had to use a custom class because there was enough nuance with how the anthropic SDK works that it would be better to simply separate out the logic. The extract questions flow needed modification of the system prompt in order to work as intended with the haiku model
This commit is contained in:
parent
e2922968d6
commit
01cdc54ad0
10 changed files with 454 additions and 5 deletions
|
@ -85,6 +85,7 @@ dependencies = [
|
||||||
"pytz ~= 2024.1",
|
"pytz ~= 2024.1",
|
||||||
"cron-descriptor == 1.4.3",
|
"cron-descriptor == 1.4.3",
|
||||||
"django_apscheduler == 0.6.2",
|
"django_apscheduler == 0.6.2",
|
||||||
|
"anthropic == 0.26.1",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
|
|
@ -833,7 +833,9 @@ class ConversationAdapters:
|
||||||
|
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
if conversation_config.model_type == "openai" and conversation_config.openai_config:
|
if (
|
||||||
|
conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
|
||||||
|
) and conversation_config.openai_config:
|
||||||
return conversation_config
|
return conversation_config
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Generated by Django 4.2.10 on 2024-05-26 12:35
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0042_serverchatsettings"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="chatmodeloptions",
|
||||||
|
name="model_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("openai", "Openai"), ("offline", "Offline"), ("anthropic", "Anthropic")],
|
||||||
|
default="offline",
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -84,6 +84,7 @@ class ChatModelOptions(BaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
|
ANTHROPIC = "anthropic"
|
||||||
|
|
||||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
|
0
src/khoj/processor/conversation/anthropic/__init__.py
Normal file
0
src/khoj/processor/conversation/anthropic/__init__.py
Normal file
204
src/khoj/processor/conversation/anthropic/anthropic_chat.py
Normal file
204
src/khoj/processor/conversation/anthropic/anthropic_chat.py
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
from khoj.database.models import Agent
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.processor.conversation.anthropic.utils import (
|
||||||
|
anthropic_chat_completion_with_backoff,
|
||||||
|
anthropic_completion_with_backoff,
|
||||||
|
)
|
||||||
|
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||||
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_questions_anthropic(
|
||||||
|
text,
|
||||||
|
model: Optional[str] = "claude-instant-1.2",
|
||||||
|
conversation_log={},
|
||||||
|
api_key=None,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=100,
|
||||||
|
location_data: LocationData = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
"""
|
||||||
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||||
|
|
||||||
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
|
chat_history = "".join(
|
||||||
|
[
|
||||||
|
f'Q: {chat["intent"]["query"]}\nKhoj: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
|
||||||
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
|
if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get dates relative to today for prompt creation
|
||||||
|
today = datetime.today()
|
||||||
|
current_new_year = today.replace(month=1, day=1)
|
||||||
|
last_new_year = current_new_year.replace(year=today.year - 1)
|
||||||
|
|
||||||
|
system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
|
||||||
|
current_date=today.strftime("%Y-%m-%d"),
|
||||||
|
day_of_week=today.strftime("%A"),
|
||||||
|
last_new_year=last_new_year.strftime("%Y"),
|
||||||
|
last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
|
||||||
|
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||||
|
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||||
|
location=location,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||||
|
chat_history=chat_history,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [ChatMessage(content=prompt, role="user")]
|
||||||
|
|
||||||
|
response = anthropic_completion_with_backoff(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model_name=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=api_key,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract, Clean Message from Claude's Response
|
||||||
|
try:
|
||||||
|
response = response.strip()
|
||||||
|
match = re.search(r"\{.*?\}", response)
|
||||||
|
if match:
|
||||||
|
response = match.group()
|
||||||
|
response = json.loads(response)
|
||||||
|
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||||
|
if not isinstance(response, list) or not response:
|
||||||
|
logger.error(f"Invalid response for constructing subqueries: {response}")
|
||||||
|
return [text]
|
||||||
|
return response
|
||||||
|
except:
|
||||||
|
logger.warning(f"Claude returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||||
|
questions = [text]
|
||||||
|
logger.debug(f"Extracted Questions by Claude: {questions}")
|
||||||
|
return questions
|
||||||
|
|
||||||
|
|
||||||
|
def anthropic_send_message_to_model(messages, api_key, model):
|
||||||
|
"""
|
||||||
|
Send message to model
|
||||||
|
"""
|
||||||
|
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
|
||||||
|
system_prompt = None
|
||||||
|
|
||||||
|
if len(messages) == 1:
|
||||||
|
messages[0].role = "user"
|
||||||
|
else:
|
||||||
|
system_prompt = ""
|
||||||
|
for message in messages.copy():
|
||||||
|
if message.role == "system":
|
||||||
|
system_prompt += message.content
|
||||||
|
messages.remove(message)
|
||||||
|
|
||||||
|
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
|
||||||
|
return anthropic_completion_with_backoff(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model_name=model,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def converse_anthropic(
|
||||||
|
references,
|
||||||
|
user_query,
|
||||||
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
|
conversation_log={},
|
||||||
|
model: Optional[str] = "claude-instant-1.2",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
completion_func=None,
|
||||||
|
conversation_commands=[ConversationCommand.Default],
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
|
location_data: LocationData = None,
|
||||||
|
user_name: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Converse with user using Anthropic's Claude
|
||||||
|
"""
|
||||||
|
# Initialize Variables
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
||||||
|
|
||||||
|
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||||
|
|
||||||
|
if agent and agent.personality:
|
||||||
|
system_prompt = prompts.custom_personality.format(
|
||||||
|
name=agent.name, bio=agent.personality, current_date=current_date
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_prompt = prompts.personality.format(current_date=current_date)
|
||||||
|
|
||||||
|
if location_data:
|
||||||
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||||
|
location_prompt = prompts.user_location.format(location=location)
|
||||||
|
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||||
|
|
||||||
|
if user_name:
|
||||||
|
user_name_prompt = prompts.user_name.format(name=user_name)
|
||||||
|
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||||
|
|
||||||
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
|
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||||
|
completion_func(chat_response=prompts.no_notes_found.format())
|
||||||
|
return iter([prompts.no_notes_found.format()])
|
||||||
|
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||||
|
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||||
|
return iter([prompts.no_online_results_found.format()])
|
||||||
|
|
||||||
|
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||||
|
conversation_primer = (
|
||||||
|
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
|
||||||
|
)
|
||||||
|
if not is_none_or_empty(compiled_references):
|
||||||
|
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
|
||||||
|
|
||||||
|
# Setup Prompt with Primer or Conversation History
|
||||||
|
messages = generate_chatml_messages_with_context(
|
||||||
|
conversation_primer,
|
||||||
|
conversation_log=conversation_log,
|
||||||
|
model_name=model,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
tokenizer_name=tokenizer_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
for message in messages.copy():
|
||||||
|
if message.role == "system":
|
||||||
|
system_prompt += message.content
|
||||||
|
messages.remove(message)
|
||||||
|
|
||||||
|
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||||
|
logger.debug(f"Conversation Context for Claude: {truncated_messages}")
|
||||||
|
|
||||||
|
# Get Response from Claude
|
||||||
|
return anthropic_chat_completion_with_backoff(
|
||||||
|
messages=messages,
|
||||||
|
compiled_references=references,
|
||||||
|
online_results=online_results,
|
||||||
|
model_name=model,
|
||||||
|
temperature=0,
|
||||||
|
api_key=api_key,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
completion_func=completion_func,
|
||||||
|
max_prompt_size=max_prompt_size,
|
||||||
|
)
|
116
src/khoj/processor/conversation/anthropic/utils.py
Normal file
116
src/khoj/processor/conversation/anthropic/utils.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import logging
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import anthropic
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
wait_random_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
anthropic_clients: Dict[str, anthropic.Anthropic] = {}
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
|
stop=stop_after_attempt(2),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def anthropic_completion_with_backoff(
|
||||||
|
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
|
||||||
|
) -> str:
|
||||||
|
if api_key not in anthropic_clients:
|
||||||
|
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||||
|
anthropic_clients[api_key] = client
|
||||||
|
else:
|
||||||
|
client = anthropic_clients[api_key]
|
||||||
|
|
||||||
|
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
|
|
||||||
|
aggregated_response = ""
|
||||||
|
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or dict()
|
||||||
|
if system_prompt:
|
||||||
|
model_kwargs["system"] = system_prompt
|
||||||
|
|
||||||
|
with client.messages.stream(
|
||||||
|
messages=formatted_messages,
|
||||||
|
model=model_name, # type: ignore
|
||||||
|
temperature=temperature,
|
||||||
|
timeout=20,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**(model_kwargs),
|
||||||
|
) as stream:
|
||||||
|
for text in stream.text_stream:
|
||||||
|
aggregated_response += text
|
||||||
|
|
||||||
|
return aggregated_response
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
stop=stop_after_attempt(2),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def anthropic_chat_completion_with_backoff(
|
||||||
|
messages,
|
||||||
|
compiled_references,
|
||||||
|
online_results,
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
api_key,
|
||||||
|
system_prompt,
|
||||||
|
max_prompt_size=None,
|
||||||
|
completion_func=None,
|
||||||
|
model_kwargs=None,
|
||||||
|
):
|
||||||
|
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||||
|
t = Thread(
|
||||||
|
target=anthropic_llm_thread,
|
||||||
|
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def anthropic_llm_thread(
|
||||||
|
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
|
||||||
|
):
|
||||||
|
if api_key not in anthropic_clients:
|
||||||
|
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||||
|
anthropic_clients[api_key] = client
|
||||||
|
else:
|
||||||
|
client: anthropic.Anthropic = anthropic_clients[api_key]
|
||||||
|
|
||||||
|
formatted_messages: List[anthropic.types.MessageParam] = [
|
||||||
|
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
|
|
||||||
|
with client.messages.stream(
|
||||||
|
messages=formatted_messages,
|
||||||
|
model=model_name, # type: ignore
|
||||||
|
temperature=temperature,
|
||||||
|
system=system_prompt,
|
||||||
|
timeout=20,
|
||||||
|
max_tokens=max_prompt_size,
|
||||||
|
**(model_kwargs or dict()),
|
||||||
|
) as stream:
|
||||||
|
for text in stream.text_stream:
|
||||||
|
g.send(text)
|
||||||
|
|
||||||
|
g.close()
|
|
@ -261,6 +261,45 @@ Khoj:
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests. Construct search queries to retrieve relevant information to answer the user's question.
|
||||||
|
- You will be provided past questions(Q) and answers(A) for context.
|
||||||
|
- Add as much context from the previous questions and answers as required into your search queries.
|
||||||
|
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||||
|
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||||
|
|
||||||
|
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
|
||||||
|
|
||||||
|
Current Date: {day_of_week}, {current_date}
|
||||||
|
User's Location: {location}
|
||||||
|
|
||||||
|
Here are some examples of how you can construct search queries to answer the user's question:
|
||||||
|
|
||||||
|
User: How was my trip to Cambodia?
|
||||||
|
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||||
|
|
||||||
|
User: What national parks did I go to last year?
|
||||||
|
Assistant: {{"queries": ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]}}
|
||||||
|
|
||||||
|
User: How can you help me?
|
||||||
|
Assistant: {{"queries": ["Social relationships", "Physical and mental health", "Education and career", "Personal life goals and habits"]}}
|
||||||
|
|
||||||
|
User: Who all did I meet here yesterday?
|
||||||
|
Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
extract_questions_anthropic_user_message = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Here's our most recent chat history:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
User: {text}
|
||||||
|
Assistant:
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
system_prompt_extract_relevant_information = """As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. The text provided is directly from within the web page. The report you create should be multiple paragraphs, and it should represent the content of the website. Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
|
system_prompt_extract_relevant_information = """As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. The text provided is directly from within the web page. The report you create should be multiple paragraphs, and it should represent the content of the website. Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
|
||||||
|
|
||||||
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
|
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
|
||||||
|
|
|
@ -27,6 +27,9 @@ from khoj.database.adapters import (
|
||||||
get_user_search_model_or_default,
|
get_user_search_model_or_default,
|
||||||
)
|
)
|
||||||
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
|
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
|
||||||
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
|
extract_questions_anthropic,
|
||||||
|
)
|
||||||
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
from khoj.processor.conversation.offline.chat_model import extract_questions_offline
|
||||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
|
@ -37,7 +40,6 @@ from khoj.routers.helpers import (
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
acreate_title_from_query,
|
acreate_title_from_query,
|
||||||
schedule_automation,
|
schedule_automation,
|
||||||
scheduled_chat,
|
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
)
|
)
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
@ -340,6 +342,18 @@ async def extract_references_and_questions(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
|
max_tokens=conversation_config.max_prompt_size,
|
||||||
|
)
|
||||||
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
chat_model = conversation_config.chat_model
|
||||||
|
inferred_queries = extract_questions_anthropic(
|
||||||
|
defiltered_query,
|
||||||
|
model=chat_model,
|
||||||
|
api_key=api_key,
|
||||||
|
conversation_log=meta_log,
|
||||||
|
location_data=location_data,
|
||||||
|
max_tokens=conversation_config.max_prompt_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
|
|
|
@ -53,6 +53,10 @@ from khoj.database.models import (
|
||||||
UserRequests,
|
UserRequests,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
|
anthropic_send_message_to_model,
|
||||||
|
converse_anthropic,
|
||||||
|
)
|
||||||
from khoj.processor.conversation.offline.chat_model import (
|
from khoj.processor.conversation.offline.chat_model import (
|
||||||
converse_offline,
|
converse_offline,
|
||||||
send_message_to_model_offline,
|
send_message_to_model_offline,
|
||||||
|
@ -113,7 +117,7 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_conversation_config
|
user_conversation_config
|
||||||
and user_conversation_config.model_type == "openai"
|
and (user_conversation_config.model_type == "openai" or user_conversation_config.model_type == "anthropic")
|
||||||
and user_conversation_config.openai_config
|
and user_conversation_config.openai_config
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
@ -508,6 +512,21 @@ async def send_message_to_model_wrapper(
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
elif conversation_config.model_type == "anthropic":
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return anthropic_send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
@ -542,8 +561,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == "openai":
|
elif conversation_config.model_type == "openai":
|
||||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
api_key = conversation_config.openai_config.api_key
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message, system_message=system_message, model_name=chat_model
|
user_message=message, system_message=system_message, model_name=chat_model
|
||||||
)
|
)
|
||||||
|
@ -553,6 +571,21 @@ def send_message_to_model_wrapper_sync(
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
|
||||||
|
elif conversation_config.model_type == "anthropic":
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return anthropic_send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
@ -631,6 +664,24 @@ def generate_chat_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif conversation_config.model_type == "anthropic":
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
chat_response = converse_anthropic(
|
||||||
|
compiled_references,
|
||||||
|
q,
|
||||||
|
online_results,
|
||||||
|
meta_log,
|
||||||
|
model=conversation_config.chat_model,
|
||||||
|
api_key=api_key,
|
||||||
|
completion_func=partial_completion,
|
||||||
|
conversation_commands=conversation_commands,
|
||||||
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
|
tokenizer_name=conversation_config.tokenizer,
|
||||||
|
location_data=location_data,
|
||||||
|
user_name=user_name,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Reference in a new issue