mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Merge pull request #554 from khoj-ai/fix/issues-with-prod-chat
Fix misc. issues with chat configuration
This commit is contained in:
commit
a5613cb08a
5 changed files with 49 additions and 11 deletions
|
@ -240,10 +240,18 @@ class ConversationAdapters:
|
|||
def get_openai_conversation_config():
|
||||
return OpenAIProcessorConversationConfig.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_openai_conversation_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_offline_chat_conversation_config():
|
||||
return OfflineChatProcessorConversationConfig.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_offline_chat_conversation_config():
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_offline_conversation_config():
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
|
||||
|
@ -267,10 +275,21 @@ class ConversationAdapters:
|
|||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_config(user: KhojUser):
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
@staticmethod
|
||||
def get_default_conversation_config():
|
||||
return ChatModelOptions.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config():
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
def save_conversation(user: KhojUser, conversation_log: dict):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
|
@ -320,10 +339,6 @@ class ConversationAdapters:
|
|||
async def get_openai_chat_config():
|
||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config():
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
|
@ -31,6 +32,10 @@ def extract_questions(
|
|||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
|
||||
def _valid_question(question: str):
|
||||
return not is_none_or_empty(question) and question != "[]"
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
[
|
||||
|
@ -70,7 +75,7 @@ def extract_questions(
|
|||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
questions = (
|
||||
split_questions = (
|
||||
response.content.strip(empty_escape_sequences)
|
||||
.replace("['", '["')
|
||||
.replace("']", '"]')
|
||||
|
@ -79,9 +84,18 @@ def extract_questions(
|
|||
.replace('"]', "")
|
||||
.split('", "')
|
||||
)
|
||||
questions = []
|
||||
|
||||
for question in split_questions:
|
||||
if question not in questions and _valid_question(question):
|
||||
questions.append(question)
|
||||
|
||||
if is_none_or_empty(questions):
|
||||
raise ValueError("GPT returned empty JSON")
|
||||
except:
|
||||
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
questions = [text]
|
||||
|
||||
logger.debug(f"Extracted Questions by GPT: {questions}")
|
||||
return questions
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ from database.models import (
|
|||
Entry as DbEntry,
|
||||
GithubConfig,
|
||||
NotionConfig,
|
||||
ChatModelOptions,
|
||||
)
|
||||
|
||||
|
||||
|
@ -122,7 +123,7 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
|
|||
def _initialize_config():
|
||||
if state.config is None:
|
||||
state.config = FullConfig()
|
||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||
state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
|
||||
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
|
@ -669,7 +670,16 @@ async def extract_references_and_questions(
|
|||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
if await ConversationAdapters.ahas_offline_chat():
|
||||
offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
|
||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
if (
|
||||
offline_chat_config
|
||||
and offline_chat_config.enabled
|
||||
and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
|
||||
):
|
||||
using_offline_chat = True
|
||||
offline_chat = await ConversationAdapters.get_offline_chat()
|
||||
chat_model = offline_chat.chat_model
|
||||
|
@ -681,7 +691,7 @@ async def extract_references_and_questions(
|
|||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif await ConversationAdapters.has_openai_chat():
|
||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||
api_key = openai_chat_config.api_key
|
||||
|
@ -706,7 +716,6 @@ async def extract_references_and_questions(
|
|||
common=common,
|
||||
)
|
||||
)
|
||||
# Dedupe the results again, as duplicates may be returned across queries.
|
||||
result_list = text_search.deduplicated_search_responses(result_list)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
|
|
|
@ -163,7 +163,7 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||
|
||||
else:
|
||||
hit_ids.add(hit.corpus_id)
|
||||
yield SearchResponse.parse_obj(
|
||||
yield SearchResponse.model_validate(
|
||||
{
|
||||
"entry": hit.entry,
|
||||
"score": hit.score,
|
||||
|
|
|
@ -39,7 +39,7 @@ def load_config_from_file(yaml_config_file: Path) -> dict:
|
|||
|
||||
def parse_config_from_string(yaml_config: dict) -> FullConfig:
|
||||
"Parse and validate config in YML string"
|
||||
return FullConfig.parse_obj(yaml_config)
|
||||
return FullConfig.model_validate(yaml_config)
|
||||
|
||||
|
||||
def parse_config_from_file(yaml_config_file):
|
||||
|
|
Loading…
Reference in a new issue