Support customization of the OpenAI base url in admin settings (#725)

- Allow self-hosted users to customize their OpenAI base URL. This makes it easy to route requests through a proxy service and extend support to other OpenAI-compatible models (see the sketch below).
- This also includes a migration that associates any existing OpenAI chat model configuration with an OpenAI processor configuration.
- Make changing the chat model a paid/subscriber feature.
- Remove usage of Langchain's OpenAI wrapper for better control over parsing input/output.
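For instance, once an admin points the processor config at a custom base URL, every chat request flows through the standard openai client against that endpoint. A minimal sketch (the proxy URL and key below are placeholders, not values from this commit):

import openai

# Hypothetical OpenAI-compatible proxy endpoint; any base URL set in the admin panel behaves the same way.
client = openai.OpenAI(api_key="sk-placeholder", base_url="http://localhost:8080/v1")
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)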
sabaimran 2024-04-27 05:54:35 -07:00 committed by GitHub
parent 49834e3b00
commit 2047b0c973
14 changed files with 219 additions and 100 deletions

View file

@@ -623,10 +623,6 @@ class ConversationAdapters:
     def get_openai_conversation_config():
         return OpenAIProcessorConversationConfig.objects.filter().first()

-    @staticmethod
-    async def aget_openai_conversation_config():
-        return await OpenAIProcessorConversationConfig.objects.filter().afirst()

     @staticmethod
     def has_valid_openai_conversation_config():
         return OpenAIProcessorConversationConfig.objects.filter().exists()
@@ -659,7 +655,7 @@ class ConversationAdapters:
     @staticmethod
     async def aget_default_conversation_config():
-        return await ChatModelOptions.objects.filter().afirst()
+        return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()

     @staticmethod
     def save_conversation(
@@ -697,29 +693,15 @@ class ConversationAdapters:
         user_conversation_config.setting = new_config
         user_conversation_config.save()

-    @staticmethod
-    async def get_default_offline_llm():
-        return await ChatModelOptions.objects.filter(model_type="offline").afirst()

     @staticmethod
     async def aget_user_conversation_config(user: KhojUser):
-        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+        config = (
+            await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
+        )
         if not config:
             return None
         return config.setting

-    @staticmethod
-    async def has_openai_chat():
-        return await OpenAIProcessorConversationConfig.objects.filter().aexists()

-    @staticmethod
-    async def aget_default_openai_llm():
-        return await ChatModelOptions.objects.filter(model_type="openai").afirst()

-    @staticmethod
-    async def get_openai_chat_config():
-        return await OpenAIProcessorConversationConfig.objects.filter().afirst()

     @staticmethod
     async def get_speech_to_text_config():
         return await SpeechToTextModelOptions.objects.filter().afirst()
@@ -744,7 +726,8 @@ class ConversationAdapters:
     @staticmethod
     def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
-        if conversation.agent and conversation.agent.chat_model:
+        agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
+        if agent and agent.chat_model:
             conversation_config = conversation.agent.chat_model
         else:
             conversation_config = ConversationAdapters.get_conversation_config(user)
@@ -760,8 +743,7 @@ class ConversationAdapters:
             return conversation_config

-        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
-        if openai_chat_config and conversation_config.model_type == "openai":
+        if conversation_config.model_type == "openai" and conversation_config.openai_config:
             return conversation_config

         else:
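In effect, the adapter now resolves a conversation's chat settings in a fixed order: a non-default agent's model first, then the user's own setting, with OpenAI-type models only considered valid when they carry a linked openai_config. A simplified sketch of that order (not the verbatim method body; the fallback branches are elided):

def resolve_valid_config(user, conversation):
    # Prefer the agent's model, unless the conversation uses the default agent.
    agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
    if agent and agent.chat_model:
        config = agent.chat_model
    else:
        config = ConversationAdapters.get_conversation_config(user)
    # An OpenAI model is only usable when its server-side credentials are attached.
    if config and config.model_type == "openai" and config.openai_config:
        return config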

View file

@@ -0,0 +1,51 @@
# Generated by Django 4.2.10 on 2024-04-24 05:46

import django.db.models.deletion
from django.db import migrations, models


def attach_openai_config(apps, schema_editor):
    OpenAIProcessorConversationConfig = apps.get_model("database", "OpenAIProcessorConversationConfig")
    openai_processor_conversation_config = OpenAIProcessorConversationConfig.objects.first()
    if openai_processor_conversation_config:
        ChatModelOptions = apps.get_model("database", "ChatModelOptions")
        for chat_model_option in ChatModelOptions.objects.all():
            if chat_model_option.model_type == "openai":
                chat_model_option.openai_config = openai_processor_conversation_config
                chat_model_option.save()


def separate_openai_config(apps, schema_editor):
    pass


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0036_delete_offlinechatprocessorconversationconfig"),
    ]

    operations = [
        migrations.AddField(
            model_name="chatmodeloptions",
            name="openai_config",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                to="database.openaiprocessorconversationconfig",
            ),
        ),
        migrations.AddField(
            model_name="openaiprocessorconversationconfig",
            name="api_base_url",
            field=models.URLField(blank=True, default=None, null=True),
        ),
        migrations.AddField(
            model_name="openaiprocessorconversationconfig",
            name="name",
            field=models.CharField(default="default", max_length=200),
            preserve_default=False,
        ),
        migrations.RunPython(attach_openai_config, reverse_code=separate_openai_config),
    ]
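Assuming a standard Django deployment, applying this migration adds the new columns and backfills existing OpenAI chat models in one step:

python manage.py migrate database

Since the RunPython reverse function is a no-op, rolling back simply drops the new columns without trying to reconstruct the old standalone config.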

View file

@@ -0,0 +1,14 @@
# Generated by Django 4.2.10 on 2024-04-25 08:57

from typing import List

from django.db import migrations


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0037_chatmodeloptions_openai_config_and_more"),
        ("database", "0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more"),
    ]

    operations: List[str] = []

View file

@@ -73,6 +73,12 @@ class Subscription(BaseModel):
     renewal_date = models.DateTimeField(null=True, default=None, blank=True)


+class OpenAIProcessorConversationConfig(BaseModel):
+    name = models.CharField(max_length=200)
+    api_key = models.CharField(max_length=200)
+    api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)


 class ChatModelOptions(BaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
@@ -82,6 +88,9 @@ class ChatModelOptions(BaseModel):
     tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
     chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
+    openai_config = models.ForeignKey(
+        OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
+    )


 class Agent(BaseModel):
@@ -211,10 +220,6 @@ class TextToImageModelConfig(BaseModel):
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)


-class OpenAIProcessorConversationConfig(BaseModel):
-    api_key = models.CharField(max_length=200)


 class SpeechToTextModelOptions(BaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
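With the relocated model and the new foreign key, several chat models can share one set of endpoint credentials, or each point at a different OpenAI-compatible server. A sketch of wiring one up (all values illustrative):

proxy_config = OpenAIProcessorConversationConfig.objects.create(
    name="default",
    api_key="sk-placeholder",
    api_base_url="http://localhost:8080/v1",  # hypothetical proxy; None means the stock OpenAI endpoint
)
ChatModelOptions.objects.create(
    chat_model="gpt-3.5-turbo",
    model_type=ChatModelOptions.ModelType.OPENAI,
    openai_config=proxy_config,
)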

View file

@@ -191,9 +191,15 @@
                 </select>
             </div>
             <div class="card-action-row">
+                {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
                 <button id="save-model" class="card-button happy" onclick="updateChatModel()">
                     Save
                 </button>
+                {% else %}
+                <button id="save-model" class="card-button" disabled>
+                    Subscribe to use different models
+                </button>
+                {% endif %}
             </div>
         </div>
     </div>

View file

@@ -121,14 +121,16 @@ def migrate_server_pg(args):
             if openai.get("chat-model") is None:
                 openai["chat-model"] = "gpt-3.5-turbo"

-            OpenAIProcessorConversationConfig.objects.create(
-                api_key=openai.get("api-key"),
+            openai_config = OpenAIProcessorConversationConfig.objects.create(
+                api_key=openai.get("api-key"), name="default"
             )

             ChatModelOptions.objects.create(
                 chat_model=openai.get("chat-model"),
                 tokenizer=processor_conversation.get("tokenizer"),
                 max_prompt_size=processor_conversation.get("max-prompt-size"),
                 model_type=ChatModelOptions.ModelType.OPENAI,
+                openai_config=openai_config,
             )

     save_config_to_file(raw_config, args.config_file)

View file

@@ -23,6 +23,7 @@ def extract_questions(
     model: Optional[str] = "gpt-4-turbo-preview",
     conversation_log={},
     api_key=None,
+    api_base_url=None,
     temperature=0,
     max_tokens=100,
     location_data: LocationData = None,
@@ -64,12 +65,12 @@ def extract_questions(
     # Get Response from GPT
     response = completion_with_backoff(
         messages=messages,
-        completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
-        model_kwargs={
-            "model_name": model,
-            "openai_api_key": api_key,
-            "model_kwargs": {"response_format": {"type": "json_object"}},
-        },
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        api_base_url=api_base_url,
+        model_kwargs={"response_format": {"type": "json_object"}},
+        openai_api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -89,7 +90,7 @@ def extract_questions(
     return questions


-def send_message_to_model(messages, api_key, model, response_type="text"):
+def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None):
     """
     Send message to model
     """
@@ -97,11 +98,10 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
-        model_kwargs={
-            "model_name": model,
-            "openai_api_key": api_key,
-            "model_kwargs": {"response_format": {"type": response_type}},
-        },
+        model=model,
+        openai_api_key=api_key,
+        api_base_url=api_base_url,
+        model_kwargs={"response_format": {"type": response_type}},
     )
@@ -112,6 +112,7 @@ def converse(
     conversation_log={},
     model: str = "gpt-3.5-turbo",
     api_key: Optional[str] = None,
+    api_base_url: Optional[str] = None,
     temperature: float = 0.2,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -181,6 +182,7 @@ def converse(
         model_name=model,
         temperature=temperature,
         openai_api_key=api_key,
+        api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
     )
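Call sites now pass endpoint details as plain keyword arguments instead of nesting them inside Langchain-style model_kwargs. The new call shape, sketched with placeholder values:

response = completion_with_backoff(
    messages=messages,
    model="gpt-3.5-turbo",
    temperature=0,
    openai_api_key=api_key,
    api_base_url=api_base_url,  # None falls back to the default OpenAI endpoint
    model_kwargs={"response_format": {"type": "json_object"}},
)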

View file

@@ -1,12 +1,9 @@
 import logging
 import os
 from threading import Thread
-from typing import Any
+from typing import Dict

+import openai
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_openai import ChatOpenAI
 from tenacity import (
     before_sleep_log,
     retry,
@@ -20,14 +17,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
 logger = logging.getLogger(__name__)

-class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
-    def __init__(self, gen: ThreadedGenerator):
-        super().__init__()
-        self.gen = gen
-
-    def on_llm_new_token(self, token: str, **kwargs) -> Any:
-        self.gen.send(token)

+openai_clients: Dict[str, openai.OpenAI] = {}

 @retry(
@@ -43,13 +33,37 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
-    if not "openai_api_key" in model_kwargs:
-        model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
+def completion_with_backoff(
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, max_tokens=None
+) -> str:
+    client_key = f"{openai_api_key}--{api_base_url}"
+    client: openai.OpenAI = openai_clients.get(client_key)
+    if not client:
+        client = openai.OpenAI(
+            api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+            base_url=api_base_url,
+        )
+        openai_clients[client_key] = client
+
+    formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+
+    chat = client.chat.completions.create(
+        stream=True,
+        messages=formatted_messages,  # type: ignore
+        model=model,  # type: ignore
+        temperature=temperature,
+        timeout=20,
+        max_tokens=max_tokens,
+        **(model_kwargs or dict()),
+    )

     aggregated_response = ""
-    for chunk in llm.stream(messages, **completion_kwargs):
-        aggregated_response += chunk.content
+    for chunk in chat:
+        delta_chunk = chunk.choices[0].delta  # type: ignore
+        if isinstance(delta_chunk, str):
+            aggregated_response += delta_chunk
+        elif delta_chunk.content:
+            aggregated_response += delta_chunk.content

     return aggregated_response
@@ -73,30 +87,45 @@ def chat_completion_with_backoff(
     model_name,
     temperature,
     openai_api_key=None,
+    api_base_url=None,
     completion_func=None,
     model_kwargs=None,
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
+    t = Thread(
+        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+    )
     t.start()
     return g


-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
-    callback_handler = StreamingChatCallbackHandler(g)
-    chat = ChatOpenAI(
-        streaming=True,
-        verbose=True,
-        callback_manager=BaseCallbackManager([callback_handler]),
-        model_name=model_name,  # type: ignore
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+    client_key = f"{openai_api_key}--{api_base_url}"
+    if client_key not in openai_clients:
+        client: openai.OpenAI = openai.OpenAI(
+            api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+            base_url=api_base_url,
+        )
+        openai_clients[client_key] = client
+    else:
+        client: openai.OpenAI = openai_clients[client_key]
+
+    formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+
+    chat = client.chat.completions.create(
+        stream=True,
+        messages=formatted_messages,
+        model=model_name,  # type: ignore
         temperature=temperature,
-        openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
-        model_kwargs=model_kwargs,
-        request_timeout=20,
-        max_retries=1,
-        client=None,
+        timeout=20,
+        **(model_kwargs or dict()),
     )
-    chat(messages=messages)

+    for chunk in chat:
+        delta_chunk = chunk.choices[0].delta
+        if isinstance(delta_chunk, str):
+            g.send(delta_chunk)
+        elif delta_chunk.content:
+            g.send(delta_chunk.content)

     g.close()
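Because clients are keyed by API key and base URL, each distinct endpoint gets exactly one openai.OpenAI instance, and its connection pool, per process. The caching pattern condensed into a helper for clarity (the diff inlines it in both completion_with_backoff and llm_thread):

import os
from typing import Dict

import openai

openai_clients: Dict[str, openai.OpenAI] = {}

def get_openai_client(api_key=None, base_url=None) -> openai.OpenAI:
    # One client per (api key, base url) pair; base_url=None targets api.openai.com.
    client_key = f"{api_key}--{base_url}"
    if client_key not in openai_clients:
        openai_clients[client_key] = openai.OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
        )
    return openai_clients[client_key]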

View file

@@ -14,6 +14,7 @@ from transformers import AutoTokenizer
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
+from khoj.utils import state
 from khoj.utils.helpers import is_none_or_empty, merge_dicts

 logger = logging.getLogger(__name__)
@@ -186,19 +187,31 @@ def truncate_messages(
     max_prompt_size,
     model_name: str,
     loaded_model: Optional[Llama] = None,
-    tokenizer_name="hf-internal-testing/llama-tokenizer",
+    tokenizer_name=None,
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
+    default_tokenizer = "hf-internal-testing/llama-tokenizer"

     try:
         if loaded_model:
             encoder = loaded_model.tokenizer()
         elif model_name.startswith("gpt-"):
             encoder = tiktoken.encoding_for_model(model_name)
+        elif tokenizer_name:
+            if tokenizer_name in state.pretrained_tokenizers:
+                encoder = state.pretrained_tokenizers[tokenizer_name]
+            else:
+                encoder = AutoTokenizer.from_pretrained(tokenizer_name)
+                state.pretrained_tokenizers[tokenizer_name] = encoder
         else:
             encoder = download_model(model_name).tokenizer()
     except:
-        encoder = AutoTokenizer.from_pretrained(tokenizer_name)
+        if default_tokenizer in state.pretrained_tokenizers:
+            encoder = state.pretrained_tokenizers[default_tokenizer]
+        else:
+            encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+            state.pretrained_tokenizers[default_tokenizer] = encoder
         logger.warning(
             f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )
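The new state.pretrained_tokenizers dict applies the same memoization idea to tokenizers, so each HuggingFace tokenizer is downloaded and constructed at most once per process. Condensed into a helper (a sketch; the diff writes the lookup out inline):

from transformers import AutoTokenizer

from khoj.utils import state

def get_cached_tokenizer(tokenizer_name: str):
    # Reuse an already-loaded tokenizer; otherwise load it once and cache it.
    if tokenizer_name not in state.pretrained_tokenizers:
        state.pretrained_tokenizers[tokenizer_name] = AutoTokenizer.from_pretrained(tokenizer_name)
    return state.pretrained_tokenizers[tokenizer_name]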

View file

@@ -267,7 +267,6 @@ async def transcribe(
 async def extract_references_and_questions(
     request: Request,
-    common: CommonQueryParams,
     meta_log: dict,
     q: str,
     n: int,
@@ -303,14 +302,12 @@ async def extract_references_and_questions(
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        conversation_config = await ConversationAdapters.aget_conversation_config(user)
-        if conversation_config is None:
-            conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        conversation_config = await ConversationAdapters.aget_default_conversation_config()

         if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             using_offline_chat = True
-            default_offline_llm = await ConversationAdapters.get_default_offline_llm()
-            chat_model = default_offline_llm.chat_model
-            max_tokens = default_offline_llm.max_prompt_size
+            chat_model = conversation_config.chat_model
+            max_tokens = conversation_config.max_prompt_size

             if state.offline_chat_processor_config is None:
                 state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
@@ -324,11 +321,10 @@ async def extract_references_and_questions(
                 location_data=location_data,
                 max_prompt_size=conversation_config.max_prompt_size,
             )
-        elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
-            openai_chat_config = await ConversationAdapters.get_openai_chat_config()
-            default_openai_llm = await ConversationAdapters.aget_default_openai_llm()
+        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
+            openai_chat_config = conversation_config.openai_config
             api_key = openai_chat_config.api_key
-            chat_model = default_openai_llm.chat_model
+            chat_model = conversation_config.chat_model
             inferred_queries = extract_questions(
                 defiltered_query,
                 model=chat_model,

View file

@@ -380,7 +380,7 @@ async def websocket_endpoint(
             q = q.replace(f"/{cmd.value}", "").strip()

         compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
-            websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
+            websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
         )

         if compiled_references:
@@ -575,7 +575,7 @@ async def chat(
     user_name = await aget_user_name(user)

     compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
-        request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
+        request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
     )
     online_results: Dict[str, Dict] = {}

View file

@@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async
 from fastapi import APIRouter, HTTPException, Request
 from fastapi.requests import Request
 from fastapi.responses import Response
-from starlette.authentication import requires
+from starlette.authentication import has_required_scope, requires

 from khoj.database import adapters
 from khoj.database.adapters import ConversationAdapters, EntryAdapters
@@ -20,6 +20,7 @@ from khoj.database.models import (
     LocalPdfConfig,
     LocalPlaintextConfig,
     NotionConfig,
+    Subscription,
 )
 from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
 from khoj.utils import constants, state
@@ -236,6 +237,10 @@ async def update_chat_model(
     client: Optional[str] = None,
 ):
     user = request.user.object
+    subscribed = has_required_scope(request, ["premium"])
+    if not subscribed:
+        raise HTTPException(status_code=403, detail="User is not subscribed to premium")

     new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))

View file

@@ -70,13 +70,14 @@ def validate_conversation_config():
     if default_config is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

-    if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
+    if default_config.model_type == "openai" and not default_config.openai_config:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")


 async def is_ready_to_chat(user: KhojUser):
-    has_openai_config = await ConversationAdapters.has_openai_chat()
-    user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
+    user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
+        await ConversationAdapters.aget_default_conversation_config()
+    )

     if user_conversation_config and user_conversation_config.model_type == "offline":
         chat_model = user_conversation_config.chat_model
@@ -86,8 +87,14 @@ async def is_ready_to_chat(user: KhojUser):
             state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
         return True

-    if not has_openai_config:
-        raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
+    if (
+        user_conversation_config
+        and user_conversation_config.model_type == "openai"
+        and user_conversation_config.openai_config
+    ):
+        return True
+
+    raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")


 def update_telemetry_state(
@@ -407,8 +414,9 @@ async def send_message_to_model_wrapper(
         )

     elif conversation_config.model_type == "openai":
-        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        openai_chat_config = conversation_config.openai_config
         api_key = openai_chat_config.api_key
+        api_base_url = openai_chat_config.api_base_url
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
@@ -418,7 +426,11 @@
         )

         openai_response = send_message_to_model(
-            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+            response_type=response_type,
+            api_base_url=api_base_url,
         )

         return openai_response
@@ -480,7 +492,7 @@ def generate_chat_response(
         )

     elif conversation_config.model_type == "openai":
-        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
+        openai_chat_config = conversation_config.openai_config
         api_key = openai_chat_config.api_key
         chat_model = conversation_config.chat_model
         chat_response = converse(
@@ -490,6 +502,7 @@ def generate_chat_response(
             conversation_log=meta_log,
             model=chat_model,
             api_key=api_key,
+            api_base_url=openai_chat_config.api_base_url,
             completion_func=partial_completion,
             conversation_commands=conversation_commands,
             max_prompt_size=conversation_config.max_prompt_size,

View file

@@ -2,7 +2,7 @@ import os
 import threading
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List
+from typing import Any, Dict, List

 from openai import OpenAI
 from whisper import Whisper
@@ -34,6 +34,7 @@ khoj_version: str = None
 device = get_device()
 chat_on_gpu: bool = True
 anonymous_mode: bool = False
+pretrained_tokenizers: Dict[str, Any] = dict()
 billing_enabled: bool = (
     os.getenv("STRIPE_API_KEY") is not None
     and os.getenv("STRIPE_SIGNING_SECRET") is not None