Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 23:48:56 +01:00)
Let Offline chat override OpenAI API settings (#362)
* Let offline chat override OpenAI API settings
* Download the offline model whenever offline chat is enabled
* Add a progress bar to the llamav2 model download to track progress
* Change the ordering of the n (result count) check due to the switch of default processor
* Flip the ordering of the offline/OpenAI checks when extracting questions from the query
parent 12cfb48f16 · commit 9f78db0579

4 changed files with 30 additions and 19 deletions
@@ -2,7 +2,7 @@ import os
 import logging
 import requests
 from gpt4all import GPT4All
-import tqdm
+from tqdm import tqdm
 
 from khoj.processor.conversation.gpt4all import model_metadata
 
@@ -24,9 +24,17 @@ def download_model(model_name):
         logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
         with requests.get(url, stream=True) as r:
             r.raise_for_status()
-            with open(filename, "wb") as f:
+            total_size = int(r.headers.get("content-length", 0))
+            with open(filename, "wb") as f, tqdm(
+                unit="B",  # display progress in bytes
+                unit_scale=True,  # let tqdm auto-scale to KB/MB and beyond
+                unit_divisor=1024,  # divisor applied when unit_scale is true
+                total=total_size,  # total download size
+                desc=filename.split("/")[-1],  # prefix shown on the progress bar
+            ) as progress_bar:
                 for chunk in r.iter_content(chunk_size=8192):
                     f.write(chunk)
+                    progress_bar.update(len(chunk))
         return GPT4All(model_name)
     except Exception as e:
         logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
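The pattern above streams the file in fixed-size chunks and advances the tqdm bar by each chunk's byte count, so the bar tracks bytes rather than loop iterations. Here is a minimal, self-contained sketch of the same technique; the URL and file path are placeholders for illustration, not Khoj's actual model location (Khoj resolves those from its model_metadata):

```python
import requests
from tqdm import tqdm

# Placeholder URL and path, for illustration only.
url = "https://example.com/models/llama-2-7b-chat.ggmlv3.q4_0.bin"
filename = "/tmp/llama-2-7b-chat.ggmlv3.q4_0.bin"

with requests.get(url, stream=True) as r:
    r.raise_for_status()
    # Fall back to 0 if the server omits Content-Length.
    total_size = int(r.headers.get("content-length", 0))
    with open(filename, "wb") as f, tqdm(
        total=total_size,
        unit="B",           # measure progress in bytes
        unit_scale=True,    # auto-scale the displayed unit
        unit_divisor=1024,  # scale by 1024 (KiB/MiB/GiB)
        desc=filename.split("/")[-1],
    ) as progress_bar:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)
            # Advance by bytes written, keeping the bar byte-accurate.
            progress_bar.update(len(chunk))
```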
@@ -709,21 +709,22 @@ async def extract_references_and_questions(
     if conversation_type == "notes":
         # Infer search queries from user message
         with timer("Extracting search queries took", logger):
-            if state.processor_config.conversation and state.processor_config.conversation.openai_model:
-                api_key = state.processor_config.conversation.openai_model.api_key
-                chat_model = state.processor_config.conversation.openai_model.chat_model
-                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
-            else:
+            # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
+            if state.processor_config.conversation.enable_offline_chat:
                 loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
                 inferred_queries = extract_questions_offline(
                     q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
                 )
+            elif state.processor_config.conversation.openai_model:
+                api_key = state.processor_config.conversation.openai_model.api_key
+                chat_model = state.processor_config.conversation.openai_model.chat_model
+                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
 
     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
         result_list = []
         for query in inferred_queries:
-            n_items = n if state.processor_config.conversation.openai_model else min(n, 3)
+            n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
             result_list.extend(
                 await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False)
             )
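Two behavior changes live in this hunk: the offline-chat branch is now checked before the OpenAI branch when inferring search queries, and the per-query result cap keys off enable_offline_chat instead of openai_model. A small runnable sketch of the flipped cap, using a simplified stand-in for Khoj's conversation config:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ConversationConfig:
    # Simplified stand-in for state.processor_config.conversation.
    enable_offline_chat: bool = False
    openai_model: Optional[str] = None

def results_per_query(n: int, config: ConversationConfig) -> int:
    # New expression: cap at 3 whenever offline chat is enabled, even if
    # an OpenAI model is also configured. The old expression only capped
    # when no OpenAI model was set.
    return min(n, 3) if config.enable_offline_chat else n

# Offline chat enabled alongside an OpenAI model: the cap now applies.
print(results_per_query(5, ConversationConfig(True, "gpt-3.5-turbo")))   # 3
print(results_per_query(5, ConversationConfig(False, "gpt-3.5-turbo")))  # 5
```

Capping offline queries at three results presumably keeps the prompt within the local model's smaller context window, though the commit itself does not state the rationale.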
@@ -86,6 +86,7 @@ def generate_chat_response(
     # Switch to general conversation type if no relevant notes found for the given query
     conversation_type = "notes" if compiled_references else "general"
     logger.debug(f"Conversation Type: {conversation_type}")
     chat_response = None
+
     try:
         with timer("Generating chat response took", logger):
@@ -98,7 +99,17 @@ def generate_chat_response(
                 meta_log=meta_log,
             )
 
-            if state.processor_config.conversation.openai_model:
+            if state.processor_config.conversation.enable_offline_chat:
+                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+                chat_response = converse_offline(
+                    references=compiled_references,
+                    user_query=q,
+                    loaded_model=loaded_model,
+                    conversation_log=meta_log,
+                    completion_func=partial_completion,
+                )
+
+            elif state.processor_config.conversation.openai_model:
                 api_key = state.processor_config.conversation.openai_model.api_key
                 chat_model = state.processor_config.conversation.openai_model.chat_model
                 chat_response = converse(
@@ -109,15 +120,6 @@ def generate_chat_response(
                     api_key=api_key,
                     completion_func=partial_completion,
                 )
-            else:
-                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
-                chat_response = converse_offline(
-                    references=compiled_references,
-                    user_query=q,
-                    loaded_model=loaded_model,
-                    conversation_log=meta_log,
-                    completion_func=partial_completion,
-                )
 
     except Exception as e:
         logger.error(e, exc_info=True)
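Together with the previous hunk, this moves the converse_offline branch ahead of the OpenAI converse branch, so enabling offline chat overrides a configured OpenAI model when generating the response as well. A condensed, runnable sketch of the new dispatch order; the stub functions below merely stand in for Khoj's real converse helpers:

```python
from typing import Optional

# Stubs standing in for Khoj's converse_offline/converse; only the
# branch ordering below mirrors the diff.
def converse_offline(user_query: str) -> str:
    return f"[offline] {user_query}"

def converse(user_query: str) -> str:
    return f"[openai] {user_query}"

class ConversationConfig:
    def __init__(self, enable_offline_chat: bool, openai_model: Optional[str]):
        self.enable_offline_chat = enable_offline_chat
        self.openai_model = openai_model

def generate_chat_response(q: str, config: ConversationConfig) -> Optional[str]:
    # Offline chat is checked first, so it wins when both are configured;
    # before this commit the openai_model branch was checked first.
    if config.enable_offline_chat:
        return converse_offline(q)
    elif config.openai_model:
        return converse(q)
    return None

print(generate_chat_response("hello", ConversationConfig(True, "gpt-4")))
# -> "[offline] hello"
```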
@@ -94,7 +94,7 @@ class ConversationProcessorConfigModel:
         self.chat_session: List[str] = []
         self.meta_log: dict = {}
 
-        if not self.openai_model and self.enable_offline_chat:
+        if self.enable_offline_chat:
             self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
         else:
             self.gpt4all_model.loaded_model = None
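This one-line gate change is what makes the override take effect at startup: the GPT4All model now downloads whenever offline chat is enabled, where previously a configured OpenAI model suppressed the download. A tiny script comparing the old and new gates across all four config states:

```python
# Compare the old and new download gates for every config combination.
for openai_model, enable_offline_chat in [
    (None, False), (None, True), ("gpt-4", False), ("gpt-4", True),
]:
    old_gate = (not openai_model) and enable_offline_chat
    new_gate = enable_offline_chat
    print(f"openai_model={openai_model!r:>8}, offline={enable_offline_chat!s:>5}: "
          f"old downloads={old_gate}, new downloads={new_gate}")
# Only ("gpt-4", True) differs: offline chat now overrides the OpenAI config.
```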