Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
Honor user's chat settings when running the extract questions phase
- Add marginally better error handling when GPT gives a malformed response to the extract questions method
- Remove debug log lines
This commit is contained in:
parent 67156e6aec · commit a8a25ceac2

3 changed files with 46 additions and 21 deletions
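The diff covers three areas: the `ConversationAdapters` class in the database adapters (new async accessors for chat configuration), the OpenAI `extract_questions` implementation (stricter parsing and validation of GPT's response), and the chat API's `extract_references_and_questions` route (backend selection now honors the user's chat settings, and debug logging plus scratch-file writes are removed). The first two hunks add async variants of existing accessors: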
```diff
@@ -240,10 +240,18 @@ class ConversationAdapters:
     def get_openai_conversation_config():
         return OpenAIProcessorConversationConfig.objects.filter().first()
 
+    @staticmethod
+    async def aget_openai_conversation_config():
+        return await OpenAIProcessorConversationConfig.objects.filter().afirst()
+
     @staticmethod
     def get_offline_chat_conversation_config():
         return OfflineChatProcessorConversationConfig.objects.filter().first()
 
+    @staticmethod
+    async def aget_offline_chat_conversation_config():
+        return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
+
     @staticmethod
     def has_valid_offline_conversation_config():
         return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
```
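The `a`-prefixed variants rely on the async ORM entry points Django has shipped since 4.1 (`afirst` and friends), so async request handlers can query configuration without a `sync_to_async` wrapper. A minimal standalone sketch of the sync/async pairing, using a toy in-memory store rather than Khoj's Django models:

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChatConfig:
    enabled: bool
    chat_model: str


class ConfigStore:
    """Toy stand-in for a Django manager: first() is sync, afirst() is async."""

    def __init__(self, configs: List[ChatConfig]):
        self._configs = configs

    def first(self) -> Optional[ChatConfig]:
        return self._configs[0] if self._configs else None

    async def afirst(self) -> Optional[ChatConfig]:
        # Django's afirst() runs the query off the event loop; here we just
        # yield control once to keep the example honest about being async.
        await asyncio.sleep(0)
        return self.first()


async def main() -> None:
    store = ConfigStore([ChatConfig(enabled=True, chat_model="gpt-3.5-turbo")])
    print(await store.afirst())


asyncio.run(main())
```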
```diff
@@ -267,10 +275,21 @@ class ConversationAdapters:
             return None
         return config.setting
 
+    @staticmethod
+    async def aget_conversation_config(user: KhojUser):
+        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+        if not config:
+            return None
+        return config.setting
+
     @staticmethod
     def get_default_conversation_config():
         return ChatModelOptions.objects.filter().first()
 
+    @staticmethod
+    async def aget_default_conversation_config():
+        return await ChatModelOptions.objects.filter().afirst()
+
     @staticmethod
     def save_conversation(user: KhojUser, conversation_log: dict):
         conversation = Conversation.objects.filter(user=user)
```
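`aget_conversation_config` returns the user's own chat setting or `None`, while `aget_default_conversation_config` supplies the server-wide default; the route code later in this commit chains them in exactly that order. A condensed, standalone restatement of the resolution rule (function and argument names are illustrative, not from the commit):

```python
from typing import Optional


def resolve_conversation_config(
    user_setting: Optional[str], default_setting: Optional[str]
) -> Optional[str]:
    # The user's own chat setting wins; otherwise fall back to the server
    # default. Both may be absent if no chat model is configured at all.
    return user_setting if user_setting is not None else default_setting


assert resolve_conversation_config(None, "gpt-4") == "gpt-4"
assert resolve_conversation_config("llama-v2", "gpt-4") == "llama-v2"
```

The next hunks move to the OpenAI extract-questions flow: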
```diff
@@ -1,5 +1,6 @@
 # Standard Packages
 import logging
+import json
 from datetime import datetime, timedelta
 from typing import Optional
```
```diff
@@ -31,6 +32,10 @@ def extract_questions(
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
 
+    def _valid_question(question: str):
+        return not is_none_or_empty(question) and question != "[]"
+
     # Extract Past User Message and Inferred Questions from Conversation Log
     chat_history = "".join(
         [
```
```diff
@@ -70,7 +75,7 @@ def extract_questions(
 
     # Extract, Clean Message from GPT's Response
     try:
-        questions = (
+        split_questions = (
             response.content.strip(empty_escape_sequences)
             .replace("['", '["')
             .replace("']", '"]')
```
```diff
@@ -79,9 +84,18 @@ def extract_questions(
             .replace('"]', "")
             .split('", "')
         )
+        questions = []
+
+        for question in split_questions:
+            if question not in questions and _valid_question(question):
+                questions.append(question)
+
+        if is_none_or_empty(questions):
+            raise ValueError("GPT returned empty JSON")
     except:
         logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
         questions = [text]
 
     logger.debug(f"Extracted Questions by GPT: {questions}")
     return questions
```
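The parsing here is string surgery rather than real JSON decoding: the bracket and quote replacements coerce GPT's list-ish output into splittable form, and each candidate question is then kept only if it is non-empty, not the literal `"[]"`, and not already seen. The lines elided between the two hunks presumably strip the opening `["`; the sketch below fills that gap with an assumed `.replace('["', "")`. A standalone restatement of the flow, with `is_none_or_empty` reimplemented inline for illustration:

```python
def is_none_or_empty(item) -> bool:
    # Inline stand-in for Khoj's helper of the same name.
    return item is None or len(item) == 0


def _valid_question(question: str) -> bool:
    return not is_none_or_empty(question) and question != "[]"


def parse_questions(content: str, user_message: str) -> list[str]:
    """Sketch of the cleaned-up extraction flow from the hunks above."""
    try:
        split_questions = (
            content.strip()
            .replace("['", '["')
            .replace("']", '"]')
            .replace('["', "")  # assumed from the elided lines
            .replace('"]', "")
            .split('", "')
        )
        # Keep each question once, skipping empties and the literal "[]"
        questions: list[str] = []
        for question in split_questions:
            if question not in questions and _valid_question(question):
                questions.append(question)
        if is_none_or_empty(questions):
            raise ValueError("GPT returned empty JSON")
    except Exception:
        # The real code uses a bare `except:` and logs a warning before
        # falling back to the raw user message as the only search query.
        questions = [user_message]
    return questions


print(parse_questions('["When did I visit Paris?", "When did I visit Paris?"]', "paris"))
# -> ['When did I visit Paris?']
```

The remaining hunks rework backend selection in the chat API route: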
```diff
@@ -55,6 +55,7 @@ from database.models import (
     Entry as DbEntry,
     GithubConfig,
     NotionConfig,
+    ChatModelOptions,
 )
```
```diff
@@ -669,7 +670,16 @@ async def extract_references_and_questions(
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        if await ConversationAdapters.ahas_offline_chat():
+        offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
+        conversation_config = await ConversationAdapters.aget_conversation_config(user)
+        if conversation_config is None:
+            conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        if (
+            offline_chat_config
+            and offline_chat_config.enabled
+            and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE.value
+        ):
             using_offline_chat = True
             offline_chat = await ConversationAdapters.get_offline_chat()
             chat_model = offline_chat.chat_model
```
```diff
@@ -681,7 +691,7 @@ async def extract_references_and_questions(
             inferred_queries = extract_questions_offline(
                 defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
             )
-        elif await ConversationAdapters.has_openai_chat():
+        elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI.value:
             openai_chat_config = await ConversationAdapters.get_openai_chat_config()
             openai_chat = await ConversationAdapters.get_openai_chat()
             api_key = openai_chat_config.api_key
```
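Taken together, the two branches now key off the resolved `conversation_config.model_type` instead of merely checking which processors exist: offline chat runs only when an enabled offline config is present and the chat model is offline-type; OpenAI runs when its processor config exists and the model is OpenAI-type. A standalone restatement of that rule, with `"offline"`/`"openai"` standing in for the `ChatModelOptions.ModelType` enum values:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OfflineChatConfig:
    enabled: bool


def select_chat_backend(
    offline_config: Optional[OfflineChatConfig],
    has_openai_config: bool,
    model_type: str,
) -> Optional[str]:
    # Same branch order as the route: offline is preferred only when it is
    # configured, enabled, and the resolved chat model is offline-type.
    if offline_config and offline_config.enabled and model_type == "offline":
        return "offline"
    if has_openai_config and model_type == "openai":
        return "openai"
    return None  # no usable chat backend configured


assert select_chat_backend(OfflineChatConfig(enabled=True), True, "openai") == "openai"
assert select_chat_backend(None, True, "openai") == "openai"
assert select_chat_backend(OfflineChatConfig(enabled=True), False, "offline") == "offline"
```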
```diff
@@ -690,11 +700,6 @@ async def extract_references_and_questions(
                 defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
             )
 
-    logger.info(f"🔍 Inferred queries: {inferred_queries}")
-    logger.info(f"🔍 Defiltered query: {defiltered_query}")
-    logger.info(f"using max distance: {d}")
-    logger.info(f"using filters: {filters_in_query}")
-    logger.info(f"Max results: {n}")
     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
         result_list = []
```
```diff
@@ -711,20 +716,7 @@ async def extract_references_and_questions(
                     common=common,
                 )
             )
-        logger.info(f"🔍 Found {len(result_list)} results")
-        logger.info(f"Confidence scores: {[item.score for item in result_list]}")
-        # Dedupe the results again, as duplicates may be returned across queries.
-        with open("compiled_references_pre_deduped.txt", "w") as f:
-            for item in compiled_references:
-                f.write(f"{item}\n")
-
         result_list = text_search.deduplicated_search_responses(result_list)
         compiled_references = [item.additional["compiled"] for item in result_list]
-
-        with open("compiled_references_deduped.txt", "w") as f:
-            for item in compiled_references:
-                f.write(f"{item}\n")
-
-        logger.info(f"🔍 Deduped results: {len(result_list)}")
 
     return compiled_references, inferred_queries, defiltered_query
```
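Beyond deleting the info logs, this hunk removes two blocks that wrote `compiled_references_pre_deduped.txt` and `compiled_references_deduped.txt` into the server's working directory on every request, leaving `text_search.deduplicated_search_responses` as the lone dedupe step. Its implementation is not part of this diff; a hedged sketch of the contract the route relies on (the real code may key on other fields or merge scores):

```python
from typing import Any, Iterable, Iterator


def deduplicated_search_responses(results: Iterable[Any]) -> Iterator[Any]:
    # Contract sketch only: drop later results whose compiled text repeats
    # an earlier one, preserving the original (ranked) order.
    seen: set[str] = set()
    for item in results:
        compiled = item.additional["compiled"]
        if compiled not in seen:
            seen.add(compiled)
            yield item
```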