Mirror of https://github.com/khoj-ai/khoj.git, synced 2025-02-17 16:14:21 +00:00
Add subqueries for internet-connected search results and update client-side code accordingly
- Add a wrapper method for making direct queries to the LLM, used to determine any intermediate responses needed to handle the request
Parent: b8e6883a81
Commit: fee99779bf

9 changed files with 253 additions and 86 deletions
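
In outline, the commit makes online search a two-step pipeline: ask the chat model to decompose the user's query into search subqueries, then run one Serper search per subquery and key the results by subquery. A runnable toy sketch of that flow, with the two helpers stubbed out (the stub bodies are placeholders, not the real implementations):

import asyncio

# Stub of the LLM-backed helper added to khoj/routers/helpers.py in this
# commit; the real version prompts the chat model for a JSON list of subqueries.
async def generate_online_subqueries(q: str):
    return ["weather in new york", "weather in san francisco"]

# Stub of the per-subquery search; the real helper POSTs each subquery to
# https://google.serper.dev/search.
def _search_with_google(subquery: str):
    return {"organic": [], "knowledgeGraph": {}, "answerBox": [], "peopleAlsoAsk": []}

async def search_with_google(query: str):
    # One result bundle per subquery, keyed by the subquery itself -- the shape
    # the updated web and desktop clients now iterate over.
    return {subquery: _search_with_google(subquery) for subquery in await generate_online_subqueries(query)}

print(asyncio.run(search_with_google("What is the weather like in New York and San Francisco?")))
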
@@ -4,6 +4,7 @@ from datetime import date, datetime
 import secrets
 from typing import Type, List
 from datetime import date, timezone
+import random

 from django.db import models
 from django.contrib.sessions.backends.db import SessionStore

@@ -339,6 +340,26 @@ class ConversationAdapters:
     async def get_openai_chat_config():
         return await OpenAIProcessorConversationConfig.objects.filter().afirst()

+    @staticmethod
+    def get_valid_conversation_config(user: KhojUser):
+        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
+        conversation_config = ConversationAdapters.get_conversation_config(user)
+        if conversation_config is None:
+            conversation_config = ConversationAdapters.get_default_conversation_config()
+
+        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+            if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+
+            return conversation_config
+
+        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
+        if openai_chat_config and conversation_config.model_type == "openai":
+            return conversation_config
+
+        else:
+            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
+

 class EntryAdapters:
     word_filer = WordFilter()

@@ -137,30 +137,33 @@
 function processOnlineReferences(referenceSection, onlineContext) {
     let numOnlineReferences = 0;
-    if (onlineContext.organic && onlineContext.organic.length > 0) {
-        numOnlineReferences += onlineContext.organic.length;
-        for (let index in onlineContext.organic) {
-            let reference = onlineContext.organic[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+    for (let subquery in onlineContext) {
+        let onlineReference = onlineContext[subquery];
+        if (onlineReference.organic && onlineReference.organic.length > 0) {
+            numOnlineReferences += onlineReference.organic.length;
+            for (let index in onlineReference.organic) {
+                let reference = onlineReference.organic[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-        numOnlineReferences += onlineContext.knowledgeGraph.length;
-        for (let index in onlineContext.knowledgeGraph) {
-            let reference = onlineContext.knowledgeGraph[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+            numOnlineReferences += onlineReference.knowledgeGraph.length;
+            for (let index in onlineReference.knowledgeGraph) {
+                let reference = onlineReference.knowledgeGraph[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-        for (let index in onlineContext.peopleAlsoAsk) {
-            let reference = onlineContext.peopleAlsoAsk[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+            for (let index in onlineReference.peopleAlsoAsk) {
+                let reference = onlineReference.peopleAlsoAsk[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }
+    }

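
The client-side change follows from the new server payload: onlineContext is no longer one flat result object but a map from each subquery to its own result bundle, so the renderer gains an outer loop. A hypothetical example of the new shape and the counting it implies (all values made up):

online_context = {
    "weather in new york": {
        "organic": [{"title": "New York weather forecast", "link": "https://example.com/nyc"}],
        "knowledgeGraph": {},
        "answerBox": [],
        "peopleAlsoAsk": [],
    },
    "weather in san francisco": {
        "organic": [{"title": "San Francisco weather forecast", "link": "https://example.com/sf"}],
        "knowledgeGraph": {},
        "answerBox": [],
        "peopleAlsoAsk": [],
    },
}

# Mirror of the client-side counting loop: walk each subquery's result bundle.
num_online_references = 0
for subquery, online_reference in online_context.items():
    num_online_references += len(online_reference.get("organic", []))
print(num_online_references)  # 2
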
@@ -316,15 +319,28 @@
     let referenceExpandButton = document.createElement('button');
     referenceExpandButton.classList.add("reference-expand-button");
-    let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`;
-    referenceExpandButton.innerHTML = expandButtonText;
-
-    references.appendChild(referenceExpandButton);

     let referenceSection = document.createElement('div');
     referenceSection.classList.add("reference-section");
     referenceSection.classList.add("collapsed");

+    let numReferences = 0;
+
+    // If rawReferenceAsJson is a list, then count the length
+    if (Array.isArray(rawReferenceAsJson)) {
+        numReferences = rawReferenceAsJson.length;
+
+        rawReferenceAsJson.forEach((reference, index) => {
+            let polishedReference = generateReference(reference, index);
+            referenceSection.appendChild(polishedReference);
+        });
+    } else {
+        numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+    }
+
+    references.appendChild(referenceExpandButton);
+
     referenceExpandButton.addEventListener('click', function() {
         if (referenceSection.classList.contains("collapsed")) {
             referenceSection.classList.remove("collapsed");

@@ -335,10 +351,8 @@
         }
     });

-    rawReferenceAsJson.forEach((reference, index) => {
-        let polishedReference = generateReference(reference, index);
-        referenceSection.appendChild(polishedReference);
-    });
+    let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+    referenceExpandButton.innerHTML = expandButtonText;

     references.appendChild(referenceSection);
     readStream();
 } else {

@@ -419,7 +433,7 @@
     const khojToken = await window.tokenAPI.getToken();
    const headers = { 'Authorization': `Bearer ${khojToken}` };

-    fetch(`${hostURL}/api/chat/history?client=web`, { headers })
+    fetch(`${hostURL}/api/chat/history?client=desktop`, { headers })
        .then(response => response.json())
        .then(data => {
            if (data.detail) {

@@ -147,35 +147,38 @@ To get started, just start typing below. You can also type / to see a list of co
 function processOnlineReferences(referenceSection, onlineContext) {
     let numOnlineReferences = 0;
-    if (onlineContext.organic && onlineContext.organic.length > 0) {
-        numOnlineReferences += onlineContext.organic.length;
-        for (let index in onlineContext.organic) {
-            let reference = onlineContext.organic[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+    for (let subquery in onlineContext) {
+        let onlineReference = onlineContext[subquery];
+        if (onlineReference.organic && onlineReference.organic.length > 0) {
+            numOnlineReferences += onlineReference.organic.length;
+            for (let index in onlineReference.organic) {
+                let reference = onlineReference.organic[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-        numOnlineReferences += onlineContext.knowledgeGraph.length;
-        for (let index in onlineContext.knowledgeGraph) {
-            let reference = onlineContext.knowledgeGraph[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+            numOnlineReferences += onlineReference.knowledgeGraph.length;
+            for (let index in onlineReference.knowledgeGraph) {
+                let reference = onlineReference.knowledgeGraph[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-        for (let index in onlineContext.peopleAlsoAsk) {
-            let reference = onlineContext.peopleAlsoAsk[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+            for (let index in onlineReference.peopleAlsoAsk) {
+                let reference = onlineReference.peopleAlsoAsk[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }
+    }

     return numOnlineReferences;
 }

 function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
     if (context == null && onlineContext == null) {

@@ -356,7 +359,6 @@ To get started, just start typing below. You can also type / to see a list of co
     references = document.createElement('div');
     references.classList.add("references");

-
     let referenceExpandButton = document.createElement('button');
     referenceExpandButton.classList.add("reference-expand-button");

@@ -179,13 +179,6 @@ def converse_offline(


 def llm_thread(g, messages: List[ChatMessage], model: Any):
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert isinstance(model, GPT4All), "model should be of type GPT4All"
     user_message = messages[-1]
     system_message = messages[0]
     conversation_history = messages[1:-1]

@@ -204,7 +197,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     prompted_message = templated_system_message + chat_history + templated_user_message

     state.chat_lock.acquire()
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
+    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
     try:
         for response in response_iterator:
             if any(stop_word in response.strip() for stop_word in stop_words):

@@ -214,3 +207,18 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     finally:
         state.chat_lock.release()
         g.close()
+
+
+def send_message_to_model_offline(
+    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
+):
+    try:
+        from gpt4all import GPT4All
+    except ModuleNotFoundError as e:
+        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
+        raise e
+
+    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
+    gpt4all_model = loaded_model or GPT4All(model)
+
+    return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)

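
send_message_to_model_offline gives both llm_thread and the new model wrapper a single entry point to GPT4All. A usage sketch equivalent to calling it with no preloaded model, using only parameters that appear in the hunk (assumes the gpt4all package is installed; the weights are downloaded on first use):

from gpt4all import GPT4All

# Load the default GGUF weights once, then generate a single bounded,
# deterministic completion, as the new helper does.
gpt4all_model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf")
print(gpt4all_model.generate("Q: What is 2 + 2?\nA:", max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=False))
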
@@ -100,6 +100,27 @@ def extract_questions(
     return questions


+def send_message_to_model(
+    message,
+    api_key,
+    model,
+):
+    """
+    Send message to model
+    """
+    messages = [ChatMessage(content=message, role="assistant")]
+
+    # Get Response from GPT
+    return completion_with_backoff(
+        messages=messages,
+        model_name=model,
+        temperature=0,
+        max_tokens=100,
+        model_kwargs={"stop": ["A: ", "\n"]},
+        openai_api_key=api_key,
+    )
+
+
 def converse(
     references,
     online_results,

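
The OpenAI counterpart routes one-off queries through the existing completion_with_backoff helper with temperature 0, a small token cap, and stop sequences matched to the Q/A prompt style. A usage sketch (the API key and model name are placeholders, and the printed value depends on what completion_with_backoff returns):

# Assumes the khoj package is importable and a valid OpenAI API key is supplied.
from khoj.processor.conversation.openai.gpt import send_message_to_model

response = send_message_to_model(
    message="Q: When is the next lunar eclipse?\nA:",
    api_key="sk-...",  # placeholder
    model="gpt-4",  # example model name, not mandated by the diff
)
print(response)
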
@@ -121,6 +121,32 @@ Information from the internet: {online_results}
 Query: {query}""".strip()
 )

+online_search_conversation_subqueries = PromptTemplate.from_template(
+    """
+The user has a question which you can use the internet to respond to. Can you break down the question into subqueries to get the correct answer? Provide search queries as a JSON list of strings
+
+Today's date in UTC: {current_date}
+
+Here are some examples of questions and subqueries:
+Q: What is the weather like in New York?
+A: ["weather in new york"]
+
+Q: What is the weather like in New York and San Francisco?
+A: ["weather in new york", "weather in san francisco"]
+
+Q: What is the latest news about Google stock?
+A: ["google stock news"]
+
+Q: When is the next lunar eclipse?
+A: ["next lunar eclipse"]
+
+Q: How many oranges would fit in NASA's Saturn V rocket?
+A: ["volume of an orange", "volume of saturn v rocket"]
+
+This is the user's query:
+Q: {query}
+A: """.strip()
+)
+
 ## Summarize Notes
 ## --

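
The prompt constrains the model to reply with a bare JSON list of search strings, which the new helper in khoj/routers/helpers.py parses directly. A minimal sketch of rendering the template and the expected reply shape, assuming langchain's PromptTemplate (the template body and date below are truncated stand-ins):

import json

from langchain.prompts import PromptTemplate

# Truncated stand-in for the new template; the real one carries the few-shot
# examples shown in the hunk above.
template = PromptTemplate.from_template("Today's date in UTC: {current_date}\nQ: {query}\nA: ")
prompt = template.format(current_date="2023-11-21", query="What is the weather like in New York and San Francisco?")

# The model is expected to answer with a JSON list of search strings:
model_reply = '["weather in new york", "weather in san francisco"]'
print(json.loads(model_reply))
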
@@ -3,6 +3,8 @@ import json
 import os
 import logging

+from khoj.routers.helpers import generate_online_subqueries
+
 logger = logging.getLogger(__name__)

 SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")

@@ -10,29 +12,41 @@ SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
 url = "https://google.serper.dev/search"


-def search_with_google(query: str):
+async def search_with_google(query: str):
+    def _search_with_google(subquery: str):
+        payload = json.dumps(
+            {
+                "q": subquery,
+            }
+        )
+
+        headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+
+        if response.status_code != 200:
+            logger.error(response.text)
+            return {}
+
+        json_response = response.json()
+        sub_response_dict = {}
+        sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
+        sub_response_dict["organic"] = json_response.get("organic", [])
+        sub_response_dict["answerBox"] = json_response.get("answerBox", [])
+        sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+
+        return sub_response_dict
+
     if SERPER_DEV_API_KEY is None:
         raise ValueError("SERPER_DEV_API_KEY is not set")

-    payload = json.dumps(
-        {
-            "q": query,
-        }
-    )
-
-    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
-
-    response = requests.request("POST", url, headers=headers, data=payload)
-
-    if response.status_code != 200:
-        logger.error(response.text)
-        return {}
-
-    json_response = response.json()
+    # Breakdown the query into subqueries to get the correct answer
+    subqueries = await generate_online_subqueries(query)

     response_dict = {}
-    response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
-    response_dict["organic"] = json_response.get("organic", [])
-    response_dict["answerBox"] = json_response.get("answerBox", [])
-    response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+    for subquery in subqueries:
+        logger.info(f"Searching with Google for '{subquery}'")
+        response_dict[subquery] = _search_with_google(subquery)

     return response_dict

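
Each subquery still maps to one Serper request; only the orchestration around it changed. A standalone sketch of a single request, using the same endpoint, headers, and payload as _search_with_google (assumes SERPER_DEV_API_KEY is set in the environment):

import json
import os

import requests

# One Serper search request, mirroring the helper above.
url = "https://google.serper.dev/search"
payload = json.dumps({"q": "next lunar eclipse"})
headers = {"X-API-KEY": os.environ["SERPER_DEV_API_KEY"], "Content-Type": "application/json"}

response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
    json_response = response.json()
    # Keep only the sections the rest of the pipeline consumes.
    print({key: json_response.get(key) for key in ("knowledgeGraph", "organic", "answerBox", "peopleAlsoAsk")})
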
@@ -597,7 +597,7 @@ async def chat(

     elif conversation_command == ConversationCommand.Online:
         try:
-            online_results = search_with_google(defiltered_query)
+            online_results = await search_with_google(defiltered_query)
         except ValueError as e:
             return StreamingResponse(
                 iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),

@@ -6,8 +6,13 @@ from datetime import datetime
 from functools import partial
 import logging
 from time import time
+import json
 from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any

+from datetime import datetime
+
+from khoj.processor.conversation import prompts
+
 # External Packages
 from fastapi import HTTPException, Header, Request, Depends

@@ -15,10 +20,10 @@ from fastapi import HTTPException, Header, Request, Depends
 from khoj.utils import state
 from khoj.utils.config import GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, log_telemetry
-from khoj.processor.conversation.openai.gpt import converse
-from khoj.processor.conversation.gpt4all.chat_model import converse_offline
+from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
+from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
-from database.models import KhojUser, Subscription
+from database.models import KhojUser, Subscription, ChatModelOptions
 from database.adapters import ConversationAdapters

@@ -114,6 +119,65 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)


+async def generate_online_subqueries(q: str) -> List[str]:
+    """
+    Generate subqueries from the given query
+    """
+    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
+        current_date=utc_date,
+        query=q,
+    )
+
+    response = await send_message_to_model_wrapper(online_queries_prompt)
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    try:
+        response = response.strip()
+        response = json.loads(response)
+        response = [q.strip() for q in response if q.strip()]
+        if not isinstance(response, list) or not response or len(response) == 0:
+            logger.error(f"Invalid response for constructing subqueries: {response}")
+            return [q]
+        return response
+    except Exception as e:
+        logger.error(f"Invalid response for constructing subqueries: {response}")
+        return [q]
+
+
+async def send_message_to_model_wrapper(
+    message: str,
+):
+    conversation_config = await ConversationAdapters.aget_default_conversation_config()
+
+    if conversation_config is None:
+        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
+
+    if conversation_config.model_type == "offline":
+        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+
+        loaded_model = state.gpt4all_processor_config.loaded_model
+        return send_message_to_model_offline(
+            message=message,
+            loaded_model=loaded_model,
+            model=conversation_config.chat_model,
+            streaming=False,
+        )
+
+    elif conversation_config.model_type == "openai":
+        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        api_key = openai_chat_config.api_key
+        chat_model = conversation_config.chat_model
+        return send_message_to_model(
+            message=message,
+            api_key=api_key,
+            model=chat_model,
+        )
+    else:
+        raise HTTPException(status_code=500, detail="Invalid conversation config")
+
+
 def generate_chat_response(
     q: str,
     meta_log: dict,

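
generate_online_subqueries is deliberately defensive: the wrapper's reply is trusted only if it parses as a non-empty JSON list of non-blank strings; anything else is logged and the original query is searched verbatim. A standalone sketch of that validation path:

import json
import logging

logger = logging.getLogger(__name__)

def parse_subqueries(response: str, original_query: str):
    # Mirrors the validation in generate_online_subqueries: accept only a
    # non-empty JSON list of non-blank strings, else fall back to the query.
    try:
        subqueries = [s.strip() for s in json.loads(response.strip()) if s.strip()]
        if not subqueries:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [original_query]
        return subqueries
    except Exception:
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [original_query]

print(parse_subqueries('["weather in new york", "weather in san francisco"]', "weather?"))  # parsed list
print(parse_subqueries("not json", "weather?"))  # falls back to ["weather?"]
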
@@ -163,12 +227,8 @@ def generate_chat_response(
             meta_log=meta_log,
         )

-        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
-        conversation_config = ConversationAdapters.get_conversation_config(user)
-        if conversation_config is None:
-            conversation_config = ConversationAdapters.get_default_conversation_config()
-        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
-        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+        conversation_config = ConversationAdapters.get_valid_conversation_config(user)
+        if conversation_config.model_type == "offline":
             if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
                 state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)

@@ -186,7 +246,8 @@ def generate_chat_response(
                 tokenizer_name=conversation_config.tokenizer,
             )

-        elif openai_chat_config and conversation_config.model_type == "openai":
+        elif conversation_config.model_type == "openai":
+            openai_chat_config = ConversationAdapters.get_openai_conversation_config()
             api_key = openai_chat_config.api_key
             chat_model = conversation_config.chat_model
             chat_response = converse(