diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 96d8e51d..10fde9e8 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -623,10 +623,6 @@ class ConversationAdapters:
def get_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().first()
- @staticmethod
- async def aget_openai_conversation_config():
- return await OpenAIProcessorConversationConfig.objects.filter().afirst()
-
@staticmethod
def has_valid_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().exists()
@@ -659,7 +655,7 @@ class ConversationAdapters:
@staticmethod
async def aget_default_conversation_config():
- return await ChatModelOptions.objects.filter().afirst()
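+        # Prefetch the related openai_config so its fields are usable without a second (sync) DB query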
+ return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod
def save_conversation(
@@ -697,29 +693,15 @@ class ConversationAdapters:
user_conversation_config.setting = new_config
user_conversation_config.save()
- @staticmethod
- async def get_default_offline_llm():
- return await ChatModelOptions.objects.filter(model_type="offline").afirst()
-
@staticmethod
async def aget_user_conversation_config(user: KhojUser):
- config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
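+        # Prefetch the user's chat model setting and its openai_config in a single query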
+ config = (
+ await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()
+ )
if not config:
return None
return config.setting
- @staticmethod
- async def has_openai_chat():
- return await OpenAIProcessorConversationConfig.objects.filter().aexists()
-
- @staticmethod
- async def aget_default_openai_llm():
- return await ChatModelOptions.objects.filter(model_type="openai").afirst()
-
- @staticmethod
- async def get_openai_chat_config():
- return await OpenAIProcessorConversationConfig.objects.filter().afirst()
-
@staticmethod
async def get_speech_to_text_config():
return await SpeechToTextModelOptions.objects.filter().afirst()
@@ -744,7 +726,8 @@ class ConversationAdapters:
@staticmethod
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
- if conversation.agent and conversation.agent.chat_model:
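+        # Only honor the agent's chat model when the conversation uses a custom (non-default) agent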
+ agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
+ if agent and agent.chat_model:
conversation_config = conversation.agent.chat_model
else:
conversation_config = ConversationAdapters.get_conversation_config(user)
@@ -760,8 +743,7 @@ class ConversationAdapters:
return conversation_config
- openai_chat_config = ConversationAdapters.get_openai_conversation_config()
- if openai_chat_config and conversation_config.model_type == "openai":
+ if conversation_config.model_type == "openai" and conversation_config.openai_config:
return conversation_config
else:
diff --git a/src/khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py b/src/khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py
new file mode 100644
index 00000000..a7487e8c
--- /dev/null
+++ b/src/khoj/database/migrations/0037_chatmodeloptions_openai_config_and_more.py
@@ -0,0 +1,51 @@
+# Generated by Django 4.2.10 on 2024-04-24 05:46
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+def attach_openai_config(apps, schema_editor):
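+    # Link every existing OpenAI chat model to the first configured OpenAI processor config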
+ OpenAIProcessorConversationConfig = apps.get_model("database", "OpenAIProcessorConversationConfig")
+ openai_processor_conversation_config = OpenAIProcessorConversationConfig.objects.first()
+ if openai_processor_conversation_config:
+ ChatModelOptions = apps.get_model("database", "ChatModelOptions")
+ for chat_model_option in ChatModelOptions.objects.all():
+ if chat_model_option.model_type == "openai":
+ chat_model_option.openai_config = openai_processor_conversation_config
+ chat_model_option.save()
+
+
+def separate_openai_config(apps, schema_editor):
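+    # No data changes needed on reverse; rolling back the AddField operations removes the openai_config link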
+ pass
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0036_delete_offlinechatprocessorconversationconfig"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="chatmodeloptions",
+ name="openai_config",
+ field=models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ to="database.openaiprocessorconversationconfig",
+ ),
+ ),
+ migrations.AddField(
+ model_name="openaiprocessorconversationconfig",
+ name="api_base_url",
+ field=models.URLField(blank=True, default=None, null=True),
+ ),
+ migrations.AddField(
+ model_name="openaiprocessorconversationconfig",
+ name="name",
+ field=models.CharField(default="default", max_length=200),
+ preserve_default=False,
+ ),
+ migrations.RunPython(attach_openai_config, reverse_code=separate_openai_config),
+ ]
diff --git a/src/khoj/database/migrations/0038_merge_20240425_0857.py b/src/khoj/database/migrations/0038_merge_20240425_0857.py
new file mode 100644
index 00000000..3bb20914
--- /dev/null
+++ b/src/khoj/database/migrations/0038_merge_20240425_0857.py
@@ -0,0 +1,14 @@
+# Generated by Django 4.2.10 on 2024-04-25 08:57
+
+from typing import List
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0037_chatmodeloptions_openai_config_and_more"),
+ ("database", "0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more"),
+ ]
+
+    operations: List[migrations.operations.base.Operation] = []
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index ae13e980..4077c35c 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -73,6 +73,12 @@ class Subscription(BaseModel):
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
+class OpenAIProcessorConversationConfig(BaseModel):
+ name = models.CharField(max_length=200)
+ api_key = models.CharField(max_length=200)
+ api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
+
+
class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
@@ -82,6 +88,9 @@ class ChatModelOptions(BaseModel):
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
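+    # Optional link to the OpenAI (or OpenAI-compatible) server settings this chat model should use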
+ openai_config = models.ForeignKey(
+ OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
+ )
class Agent(BaseModel):
@@ -211,10 +220,6 @@ class TextToImageModelConfig(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
-class OpenAIProcessorConversationConfig(BaseModel):
- api_key = models.CharField(max_length=200)
-
-
class SpeechToTextModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index f0537a2a..a6339bcf 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -191,9 +191,15 @@
+ {% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
+ {% else %}
+
+ {% endif %}
diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py
index 0ab3522b..a770a38d 100644
--- a/src/khoj/migrations/migrate_server_pg.py
+++ b/src/khoj/migrations/migrate_server_pg.py
@@ -121,14 +121,16 @@ def migrate_server_pg(args):
if openai.get("chat-model") is None:
openai["chat-model"] = "gpt-3.5-turbo"
- OpenAIProcessorConversationConfig.objects.create(
- api_key=openai.get("api-key"),
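+        # Create a named OpenAI config and link the migrated chat model to it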
+ openai_config = OpenAIProcessorConversationConfig.objects.create(
+ api_key=openai.get("api-key"), name="default"
)
+
ChatModelOptions.objects.create(
chat_model=openai.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OPENAI,
+ openai_config=openai_config,
)
save_config_to_file(raw_config, args.config_file)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 6fb0ca0f..c25f05fd 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -23,6 +23,7 @@ def extract_questions(
model: Optional[str] = "gpt-4-turbo-preview",
conversation_log={},
api_key=None,
+ api_base_url=None,
temperature=0,
max_tokens=100,
location_data: LocationData = None,
@@ -64,12 +65,12 @@ def extract_questions(
# Get Response from GPT
response = completion_with_backoff(
messages=messages,
- completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
- model_kwargs={
- "model_name": model,
- "openai_api_key": api_key,
- "model_kwargs": {"response_format": {"type": "json_object"}},
- },
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ api_base_url=api_base_url,
+ model_kwargs={"response_format": {"type": "json_object"}},
+ openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@@ -89,7 +90,7 @@ def extract_questions(
return questions
-def send_message_to_model(messages, api_key, model, response_type="text"):
+def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None):
"""
Send message to model
"""
@@ -97,11 +98,10 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
# Get Response from GPT
return completion_with_backoff(
messages=messages,
- model_kwargs={
- "model_name": model,
- "openai_api_key": api_key,
- "model_kwargs": {"response_format": {"type": response_type}},
- },
+ model=model,
+ openai_api_key=api_key,
+ api_base_url=api_base_url,
+ model_kwargs={"response_format": {"type": response_type}},
)
@@ -112,6 +112,7 @@ def converse(
conversation_log={},
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
+ api_base_url: Optional[str] = None,
temperature: float = 0.2,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -181,6 +182,7 @@ def converse(
model_name=model,
temperature=temperature,
openai_api_key=api_key,
+ api_base_url=api_base_url,
completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
)
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 908a035d..0c37ba53 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -1,12 +1,9 @@
import logging
import os
from threading import Thread
-from typing import Any
+from typing import Dict
import openai
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_openai import ChatOpenAI
from tenacity import (
before_sleep_log,
retry,
@@ -20,14 +17,7 @@ from khoj.processor.conversation.utils import ThreadedGenerator
logger = logging.getLogger(__name__)
-
-class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
- def __init__(self, gen: ThreadedGenerator):
- super().__init__()
- self.gen = gen
-
- def on_llm_new_token(self, token: str, **kwargs) -> Any:
- self.gen.send(token)
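+# Cache OpenAI clients keyed by API key and base URL so connections are reused across requests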
+openai_clients: Dict[str, openai.OpenAI] = {}
@retry(
@@ -43,13 +33,37 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
-def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
- if not "openai_api_key" in model_kwargs:
- model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
- llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
+def completion_with_backoff(
+ messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, max_tokens=None
+) -> str:
+ client_key = f"{openai_api_key}--{api_base_url}"
+ client: openai.OpenAI = openai_clients.get(client_key)
+ if not client:
+ client = openai.OpenAI(
+ api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+ base_url=api_base_url,
+ )
+ openai_clients[client_key] = client
+
+ formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+
+ chat = client.chat.completions.create(
+ stream=True,
+ messages=formatted_messages, # type: ignore
+ model=model, # type: ignore
+ temperature=temperature,
+ timeout=20,
+ max_tokens=max_tokens,
+ **(model_kwargs or dict()),
+ )
aggregated_response = ""
- for chunk in llm.stream(messages, **completion_kwargs):
- aggregated_response += chunk.content
+ for chunk in chat:
+ delta_chunk = chunk.choices[0].delta # type: ignore
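+        # Defensively handle the case where a delta arrives as a plain string (possible with some OpenAI-compatible servers)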
+ if isinstance(delta_chunk, str):
+ aggregated_response += delta_chunk
+ elif delta_chunk.content:
+ aggregated_response += delta_chunk.content
+
return aggregated_response
@@ -73,30 +87,45 @@ def chat_completion_with_backoff(
model_name,
temperature,
openai_api_key=None,
+ api_base_url=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
- t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
+ t = Thread(
+ target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+ )
t.start()
return g
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
- callback_handler = StreamingChatCallbackHandler(g)
- chat = ChatOpenAI(
- streaming=True,
- verbose=True,
- callback_manager=BaseCallbackManager([callback_handler]),
- model_name=model_name, # type: ignore
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+    client_key = f"{openai_api_key}--{api_base_url}"
+    client: openai.OpenAI = openai_clients.get(client_key)
+    if not client:
+        client = openai.OpenAI(
+            api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+            base_url=api_base_url,
+        )
+        openai_clients[client_key] = client
+
+ formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+
+ chat = client.chat.completions.create(
+ stream=True,
+ messages=formatted_messages,
+ model=model_name, # type: ignore
temperature=temperature,
- openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
- model_kwargs=model_kwargs,
- request_timeout=20,
- max_retries=1,
- client=None,
+ timeout=20,
+ **(model_kwargs or dict()),
)
- chat(messages=messages)
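+    # Forward streamed tokens to the threaded generator as they arrive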
+ for chunk in chat:
+ delta_chunk = chunk.choices[0].delta
+ if isinstance(delta_chunk, str):
+ g.send(delta_chunk)
+ elif delta_chunk.content:
+ g.send(delta_chunk.content)
g.close()
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 877e5a43..c970c421 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -14,6 +14,7 @@ from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
+from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
@@ -186,19 +187,31 @@ def truncate_messages(
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
- tokenizer_name="hf-internal-testing/llama-tokenizer",
+ tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
+ default_tokenizer = "hf-internal-testing/llama-tokenizer"
+
try:
if loaded_model:
encoder = loaded_model.tokenizer()
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
+ elif tokenizer_name:
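+            # Load the configured tokenizer once and cache it in app state for reuse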
+ if tokenizer_name in state.pretrained_tokenizers:
+ encoder = state.pretrained_tokenizers[tokenizer_name]
+ else:
+ encoder = AutoTokenizer.from_pretrained(tokenizer_name)
+ state.pretrained_tokenizers[tokenizer_name] = encoder
else:
encoder = download_model(model_name).tokenizer()
except:
- encoder = AutoTokenizer.from_pretrained(tokenizer_name)
+ if default_tokenizer in state.pretrained_tokenizers:
+ encoder = state.pretrained_tokenizers[default_tokenizer]
+ else:
+ encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+ state.pretrained_tokenizers[default_tokenizer] = encoder
-        logger.warning(
-            f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
-        )
+        logger.warning(
+            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+        )
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 362038b7..fe90698e 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -267,7 +267,6 @@ async def transcribe(
async def extract_references_and_questions(
request: Request,
- common: CommonQueryParams,
meta_log: dict,
q: str,
n: int,
@@ -303,14 +302,12 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
- conversation_config = await ConversationAdapters.aget_conversation_config(user)
- if conversation_config is None:
- conversation_config = await ConversationAdapters.aget_default_conversation_config()
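+        # The server default chat model now carries its own OpenAI config, so a single lookup suffices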
+ conversation_config = await ConversationAdapters.aget_default_conversation_config()
+
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
using_offline_chat = True
- default_offline_llm = await ConversationAdapters.get_default_offline_llm()
- chat_model = default_offline_llm.chat_model
- max_tokens = default_offline_llm.max_prompt_size
+ chat_model = conversation_config.chat_model
+ max_tokens = conversation_config.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
@@ -324,11 +321,10 @@ async def extract_references_and_questions(
location_data=location_data,
max_prompt_size=conversation_config.max_prompt_size,
)
- elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
- openai_chat_config = await ConversationAdapters.get_openai_chat_config()
- default_openai_llm = await ConversationAdapters.aget_default_openai_llm()
+ elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
+ openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key
- chat_model = default_openai_llm.chat_model
+ chat_model = conversation_config.chat_model
inferred_queries = extract_questions(
defiltered_query,
model=chat_model,
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index b4ec2454..07b2c656 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -380,7 +380,7 @@ async def websocket_endpoint(
q = q.replace(f"/{cmd.value}", "").strip()
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- websocket, None, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
+ websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
)
if compiled_references:
@@ -575,7 +575,7 @@ async def chat(
user_name = await aget_user_name(user)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
+ request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
)
online_results: Dict[str, Dict] = {}
diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py
index 64345e3d..dd84e317 100644
--- a/src/khoj/routers/api_config.py
+++ b/src/khoj/routers/api_config.py
@@ -7,7 +7,7 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response
-from starlette.authentication import requires
+from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
@@ -20,6 +20,7 @@ from khoj.database.models import (
LocalPdfConfig,
LocalPlaintextConfig,
NotionConfig,
+ Subscription,
)
from khoj.routers.helpers import CommonQueryParams, update_telemetry_state
from khoj.utils import constants, state
@@ -236,6 +237,10 @@ async def update_chat_model(
client: Optional[str] = None,
):
user = request.user.object
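+    # Only premium subscribers may change their chat model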
+ subscribed = has_required_scope(request, ["premium"])
+
+ if not subscribed:
+ raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_conversation_processor(user, int(id))
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 96c713e0..af33564f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -70,13 +70,14 @@ def validate_conversation_config():
if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
- if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
+ if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
async def is_ready_to_chat(user: KhojUser):
- has_openai_config = await ConversationAdapters.has_openai_chat()
- user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
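+    # Fall back to the server default chat model when the user has not configured one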
+ user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
+ await ConversationAdapters.aget_default_conversation_config()
+ )
if user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model
@@ -86,8 +87,14 @@ async def is_ready_to_chat(user: KhojUser):
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True
- if not has_openai_config:
- raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
+ if (
+ user_conversation_config
+ and user_conversation_config.model_type == "openai"
+ and user_conversation_config.openai_config
+ ):
+ return True
+
+ raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def update_telemetry_state(
@@ -407,8 +414,9 @@ async def send_message_to_model_wrapper(
)
elif conversation_config.model_type == "openai":
- openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+ openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key
+ api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
@@ -418,7 +426,11 @@ async def send_message_to_model_wrapper(
)
openai_response = send_message_to_model(
- messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+ messages=truncated_messages,
+ api_key=api_key,
+ model=chat_model,
+ response_type=response_type,
+ api_base_url=api_base_url,
)
return openai_response
@@ -480,7 +492,7 @@ def generate_chat_response(
)
elif conversation_config.model_type == "openai":
- openai_chat_config = ConversationAdapters.get_openai_conversation_config()
+ openai_chat_config = conversation_config.openai_config
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
chat_response = converse(
@@ -490,6 +502,7 @@ def generate_chat_response(
conversation_log=meta_log,
model=chat_model,
api_key=api_key,
+ api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index b431225d..8270a70f 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -2,7 +2,7 @@ import os
import threading
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List
+from typing import Any, Dict, List
from openai import OpenAI
from whisper import Whisper
@@ -34,6 +34,7 @@ khoj_version: str = None
device = get_device()
chat_on_gpu: bool = True
anonymous_mode: bool = False
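+# In-memory cache of loaded chat model tokenizers (see truncate_messages)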
+pretrained_tokenizers: Dict[str, Any] = dict()
billing_enabled: bool = (
os.getenv("STRIPE_API_KEY") is not None
and os.getenv("STRIPE_SIGNING_SECRET") is not None