mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Support customization of the OpenAI base url in admin settings (#725)
- Allow self-hosted users to customize their open ai base url. This allows you to easily use a proxy service and extend support for other models. - This also includes a migration that associates any existing openai chat model configuration with an openai processor configuration - Make changing model a paid/subscriber feature - Removes usage of langchain's OpenAI wrapper for better control over parsing input/output
This commit is contained in:
parent
49834e3b00
commit
2047b0c973
14 changed files with 219 additions and 100 deletions
|
@ -623,10 +623,6 @@ class ConversationAdapters:
|
|||
def get_openai_conversation_config():
|
||||
return OpenAIProcessorConversationConfig.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_openai_conversation_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_openai_conversation_config():
|
||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||
|
@ -659,7 +655,7 @@ class ConversationAdapters:
|
|||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config():
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||
|
||||
@staticmethod
|
||||
def save_conversation(
|
||||
|
@ -697,29 +693,15 @@ class ConversationAdapters:
|
|||
user_conversation_config.setting = new_config
|
||||
user_conversation_config.save()
|
||||
|
||||
@staticmethod
|
||||
async def get_default_offline_llm():
|
||||
return await ChatModelOptions.objects.filter(model_type="offline").afirst()
|
||||
|
||||
@staticmethod
|
||||
async def aget_user_conversation_config(user: KhojUser):
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
config = (
|
||||
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
|
||||
)
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def has_openai_chat():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().aexists()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_openai_llm():
|
||||
return await ChatModelOptions.objects.filter(model_type="openai").afirst()
|
||||
|
||||
@staticmethod
|
||||
async def get_openai_chat_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
async def get_speech_to_text_config():
|
||||
return await SpeechToTextModelOptions.objects.filter().afirst()
|
||||
|
@ -744,7 +726,8 @@ class ConversationAdapters:
|
|||
|
||||
@staticmethod
|
||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
||||
if conversation.agent and conversation.agent.chat_model:
|
||||
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
||||
if agent and agent.chat_model:
|
||||
conversation_config = conversation.agent.chat_model
|
||||
else:
|
||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||
|
@ -760,8 +743,7 @@ class ConversationAdapters:
|
|||
|
||||
return conversation_config
|
||||
|
||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
||||
if openai_chat_config and conversation_config.model_type == "openai":
|
||||
if conversation_config.model_type == "openai" and conversation_config.openai_config:
|
||||
return conversation_config
|
||||
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Generated by Django 4.2.10 on 2024-04-24 05:46
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def attach_openai_config(apps, schema_editor):
|
||||
OpenAIProcessorConversationConfig = apps.get_model("database", "OpenAIProcessorConversationConfig")
|
||||
openai_processor_conversation_config = OpenAIProcessorConversationConfig.objects.first()
|
||||
if openai_processor_conversation_config:
|
||||
ChatModelOptions = apps.get_model("database", "ChatModelOptions")
|
||||
for chat_model_option in ChatModelOptions.objects.all():
|
||||
if chat_model_option.model_type == "openai":
|
||||
chat_model_option.openai_config = openai_processor_conversation_config
|
||||
chat_model_option.save()
|
||||
|
||||
|
||||
def separate_openai_config(apps, schema_editor):
|
||||
pass
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0036_delete_offlinechatprocessorconversationconfig"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="chatmodeloptions",
|
||||
name="openai_config",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="database.openaiprocessorconversationconfig",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="openaiprocessorconversationconfig",
|
||||
name="api_base_url",
|
||||
field=models.URLField(blank=True, default=None, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="openaiprocessorconversationconfig",
|
||||
name="name",
|
||||
field=models.CharField(default="default", max_length=200),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.RunPython(attach_openai_config, reverse_code=separate_openai_config),
|
||||
]
|
14
src/khoj/database/migrations/0038_merge_20240425_0857.py
Normal file
14
src/khoj/database/migrations/0038_merge_20240425_0857.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# Generated by Django 4.2.10 on 2024-04-25 08:57
|
||||
|
||||
from typing import List
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0037_chatmodeloptions_openai_config_and_more"),
|
||||
("database", "0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more"),
|
||||
]
|
||||
|
||||
operations: List[str] = []
|
|
@ -73,6 +73,12 @@ class Subscription(BaseModel):
|
|||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
name = models.CharField(max_length=200)
|
||||
api_key = models.CharField(max_length=200)
|
||||
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
|
||||
|
||||
|
||||
class ChatModelOptions(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
|
@ -82,6 +88,9 @@ class ChatModelOptions(BaseModel):
|
|||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
openai_config = models.ForeignKey(
|
||||
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||
)
|
||||
|
||||
|
||||
class Agent(BaseModel):
|
||||
|
@ -211,10 +220,6 @@ class TextToImageModelConfig(BaseModel):
|
|||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
|
||||
|
||||
class SpeechToTextModelOptions(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
|
|
|
@ -191,9 +191,15 @@
|
|||
</select>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
|
||||
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
|
||||
Save
|
||||
</button>
|
||||
{% else %}
|
||||
<button id="save-model" class="card-button" disabled>
|
||||
Subscribe to use different models
|
||||
</button>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
@ -121,14 +121,16 @@ def migrate_server_pg(args):
|
|||
if openai.get("chat-model") is None:
|
||||
openai["chat-model"] = "gpt-3.5-turbo"
|
||||
|
||||
OpenAIProcessorConversationConfig.objects.create(
|
||||
api_key=openai.get("api-key"),
|
||||
openai_config = OpenAIProcessorConversationConfig.objects.create(
|
||||
api_key=openai.get("api-key"), name="default"
|
||||
)
|
||||
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=openai.get("chat-model"),
|
||||
tokenizer=processor_conversation.get("tokenizer"),
|
||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
openai_config=openai_config,
|
||||
)
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
|
|
@ -23,6 +23,7 @@ def extract_questions(
|
|||
model: Optional[str] = "gpt-4-turbo-preview",
|
||||
conversation_log={},
|
||||
api_key=None,
|
||||
api_base_url=None,
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
location_data: LocationData = None,
|
||||
|
@ -64,12 +65,12 @@ def extract_questions(
|
|||
# Get Response from GPT
|
||||
response = completion_with_backoff(
|
||||
messages=messages,
|
||||
completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
|
||||
model_kwargs={
|
||||
"model_name": model,
|
||||
"openai_api_key": api_key,
|
||||
"model_kwargs": {"response_format": {"type": "json_object"}},
|
||||
},
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": "json_object"}},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -89,7 +90,7 @@ def extract_questions(
|
|||
return questions
|
||||
|
||||
|
||||
def send_message_to_model(messages, api_key, model, response_type="text"):
|
||||
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -97,11 +98,10 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
|
|||
# Get Response from GPT
|
||||
return completion_with_backoff(
|
||||
messages=messages,
|
||||
model_kwargs={
|
||||
"model_name": model,
|
||||
"openai_api_key": api_key,
|
||||
"model_kwargs": {"response_format": {"type": response_type}},
|
||||
},
|
||||
model=model,
|
||||
openai_api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
)
|
||||
|
||||
|
||||
|
@ -112,6 +112,7 @@ def converse(
|
|||
conversation_log={},
|
||||
model: str = "gpt-3.5-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
|
@ -181,6 +182,7 @@ def converse(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
openai_api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
)
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import logging
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
|
@ -20,14 +17,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def __init__(self, gen: ThreadedGenerator):
|
||||
super().__init__()
|
||||
self.gen = gen
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
||||
self.gen.send(token)
|
||||
openai_clients: Dict[str, openai.OpenAI] = {}
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -43,13 +33,37 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
|
||||
if not "openai_api_key" in model_kwargs:
|
||||
model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
|
||||
def completion_with_backoff(
|
||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, max_tokens=None
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
chat = client.chat.completions.create(
|
||||
stream=True,
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model, # type: ignore
|
||||
temperature=temperature,
|
||||
timeout=20,
|
||||
max_tokens=max_tokens,
|
||||
**(model_kwargs or dict()),
|
||||
)
|
||||
aggregated_response = ""
|
||||
for chunk in llm.stream(messages, **completion_kwargs):
|
||||
aggregated_response += chunk.content
|
||||
for chunk in chat:
|
||||
delta_chunk = chunk.choices[0].delta # type: ignore
|
||||
if isinstance(delta_chunk, str):
|
||||
aggregated_response += delta_chunk
|
||||
elif delta_chunk.content:
|
||||
aggregated_response += delta_chunk.content
|
||||
|
||||
return aggregated_response
|
||||
|
||||
|
||||
|
@ -73,30 +87,45 @@ def chat_completion_with_backoff(
|
|||
model_name,
|
||||
temperature,
|
||||
openai_api_key=None,
|
||||
api_base_url=None,
|
||||
completion_func=None,
|
||||
model_kwargs=None,
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
|
||||
t = Thread(
|
||||
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
|
||||
)
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name, # type: ignore
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
if client_key not in openai_clients:
|
||||
client: openai.OpenAI = openai.OpenAI(
|
||||
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=api_base_url,
|
||||
)
|
||||
openai_clients[client_key] = client
|
||||
else:
|
||||
client: openai.OpenAI = openai_clients[client_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
|
||||
chat = client.chat.completions.create(
|
||||
stream=True,
|
||||
messages=formatted_messages,
|
||||
model=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
model_kwargs=model_kwargs,
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
timeout=20,
|
||||
**(model_kwargs or dict()),
|
||||
)
|
||||
|
||||
chat(messages=messages)
|
||||
for chunk in chat:
|
||||
delta_chunk = chunk.choices[0].delta
|
||||
if isinstance(delta_chunk, str):
|
||||
g.send(delta_chunk)
|
||||
elif delta_chunk.content:
|
||||
g.send(delta_chunk.content)
|
||||
|
||||
g.close()
|
||||
|
|
|
@ -14,6 +14,7 @@ from transformers import AutoTokenizer
|
|||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import ClientApplication, KhojUser
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -186,19 +187,31 @@ def truncate_messages(
|
|||
max_prompt_size,
|
||||
model_name: str,
|
||||
loaded_model: Optional[Llama] = None,
|
||||
tokenizer_name="hf-internal-testing/llama-tokenizer",
|
||||
tokenizer_name=None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||
|
||||
try:
|
||||
if loaded_model:
|
||||
encoder = loaded_model.tokenizer()
|
||||
elif model_name.startswith("gpt-"):
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
elif tokenizer_name:
|
||||
if tokenizer_name in state.pretrained_tokenizers:
|
||||
encoder = state.pretrained_tokenizers[tokenizer_name]
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
state.pretrained_tokenizers[tokenizer_name] = encoder
|
||||
else:
|
||||
encoder = download_model(model_name).tokenizer()
|
||||
except:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
if default_tokenizer in state.pretrained_tokenizers:
|
||||
encoder = state.pretrained_tokenizers[default_tokenizer]
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||
state.pretrained_tokenizers[default_tokenizer] = encoder
|
||||
logger.warning(
|
||||
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
|
|
@ -267,7 +267,6 @@ async def transcribe(
|
|||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
|
@ -303,14 +302,12 @@ async def extract_references_and_questions(
|
|||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
using_offline_chat = True
|
||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
||||
chat_model = default_offline_llm.chat_model
|
||||
max_tokens = default_offline_llm.max_prompt_size
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
|
@ -324,11 +321,10 @@ async def extract_references_and_questions(
|
|||
location_data=location_data,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
)
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
default_openai_llm = await ConversationAdapters.aget_default_openai_llm()
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model = default_openai_llm.chat_model
|
||||
chat_model = conversation_config.chat_model
|
||||
inferred_queries = extract_questions(
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
|
|
|
@ -380,7 +380,7 @@ async def websocket_endpoint(
|
|||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
|
||||
websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
|
||||
)
|
||||
|
||||
if compiled_references:
|
||||
|
@ -575,7 +575,7 @@ async def chat(
|
|||
user_name = await aget_user_name(user)
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
|
||||
)
|
||||
online_results: Dict[str, Dict] = {}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async
|
|||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from starlette.authentication import requires
|
||||
from starlette.authentication import has_required_scope, requires
|
||||
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
|
@ -20,6 +20,7 @@ from khoj.database.models import (
|
|||
LocalPdfConfig,
|
||||
LocalPlaintextConfig,
|
||||
NotionConfig,
|
||||
Subscription,
|
||||
)
|
||||
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
|
||||
from khoj.utils import constants, state
|
||||
|
@ -236,6 +237,10 @@ async def update_chat_model(
|
|||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
|
||||
if not subscribed:
|
||||
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
|
||||
|
||||
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
|
||||
|
||||
|
|
|
@ -70,13 +70,14 @@ def validate_conversation_config():
|
|||
if default_config is None:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
|
||||
if default_config.model_type == "openai" and not default_config.openai_config:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
|
||||
async def is_ready_to_chat(user: KhojUser):
|
||||
has_openai_config = await ConversationAdapters.has_openai_chat()
|
||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
|
||||
await ConversationAdapters.aget_default_conversation_config()
|
||||
)
|
||||
|
||||
if user_conversation_config and user_conversation_config.model_type == "offline":
|
||||
chat_model = user_conversation_config.chat_model
|
||||
|
@ -86,8 +87,14 @@ async def is_ready_to_chat(user: KhojUser):
|
|||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
return True
|
||||
|
||||
if not has_openai_config:
|
||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||
if (
|
||||
user_conversation_config
|
||||
and user_conversation_config.model_type == "openai"
|
||||
and user_conversation_config.openai_config
|
||||
):
|
||||
return True
|
||||
|
||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||
|
||||
|
||||
def update_telemetry_state(
|
||||
|
@ -407,8 +414,9 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
|
@ -418,7 +426,11 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
@ -480,7 +492,7 @@ def generate_chat_response(
|
|||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
chat_response = converse(
|
||||
|
@ -490,6 +502,7 @@ def generate_chat_response(
|
|||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
import threading
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from openai import OpenAI
|
||||
from whisper import Whisper
|
||||
|
@ -34,6 +34,7 @@ khoj_version: str = None
|
|||
device = get_device()
|
||||
chat_on_gpu: bool = True
|
||||
anonymous_mode: bool = False
|
||||
pretrained_tokenizers: Dict[str, Any] = dict()
|
||||
billing_enabled: bool = (
|
||||
os.getenv("STRIPE_API_KEY") is not None
|
||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||
|
|
Loading…
Reference in a new issue