Add subqueries for internet-connected search results and update client-side code accordingly
- Add a wrapper method to make direct queries to the LLM and determine any intermediate responses needed to handle the request
This commit is contained in:
parent b8e6883a81
commit fee99779bf

9 changed files with 253 additions and 86 deletions
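At a high level, this commit routes an online query through the LLM to produce web-searchable subqueries, runs one search per subquery, and returns results keyed by subquery so the web and desktop clients can render references per subquery. A minimal sketch of that pipeline, using the functions added in the hunks below; `run_search` is a hypothetical stand-in for the per-subquery search helper, not a name from this diff:

    import asyncio

    async def answer_online(query: str) -> dict:
        # 1. LLM decomposes the query into subqueries (falls back to [query] on bad output)
        subqueries = await generate_online_subqueries(query)
        # 2. One web search per subquery; results are keyed by the subquery string
        return {subquery: run_search(subquery) for subquery in subqueries}

    # asyncio.run(answer_online("What is the weather like in New York and San Francisco?"))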
@@ -4,6 +4,7 @@ from datetime import date, datetime
import secrets
from typing import Type, List
from datetime import date, timezone
import random

from django.db import models
from django.contrib.sessions.backends.db import SessionStore
@@ -339,6 +340,26 @@ class ConversationAdapters:
    async def get_openai_chat_config():
        return await OpenAIProcessorConversationConfig.objects.filter().afirst()

    @staticmethod
    def get_valid_conversation_config(user: KhojUser):
        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
        conversation_config = ConversationAdapters.get_conversation_config(user)
        if conversation_config is None:
            conversation_config = ConversationAdapters.get_default_conversation_config()

        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
            if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
                state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)

            return conversation_config

        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
        if openai_chat_config and conversation_config.model_type == "openai":
            return conversation_config

        else:
            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")


class EntryAdapters:
    word_filer = WordFilter()
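get_valid_conversation_config centralizes the config checks that generate_chat_response previously inlined (see the last hunks of this diff): it resolves the user's config, falls back to the default, lazily loads the GPT4All model for offline chat, and raises if neither backend is usable. A minimal usage sketch, for illustration only; the try/except mirrors the ValueError raised above:

    # Hypothetical caller; `user` is an authenticated KhojUser
    try:
        conversation_config = ConversationAdapters.get_valid_conversation_config(user)
    except ValueError:
        # Neither offline chat nor OpenAI chat is configured on this server
        conversation_config = None

    if conversation_config and conversation_config.model_type == "offline":
        # Loaded as a side effect of the call above
        model = state.gpt4all_processor_config.loaded_model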
@@ -137,30 +137,33 @@
function processOnlineReferences(referenceSection, onlineContext) {
    let numOnlineReferences = 0;
    if (onlineContext.organic && onlineContext.organic.length > 0) {
        numOnlineReferences += onlineContext.organic.length;
        for (let index in onlineContext.organic) {
            let reference = onlineContext.organic[index];
            let polishedReference = generateOnlineReference(reference, index);
            referenceSection.appendChild(polishedReference);
    for (let subquery in onlineContext) {
        let onlineReference = onlineContext[subquery];
        if (onlineReference.organic && onlineReference.organic.length > 0) {
            numOnlineReferences += onlineReference.organic.length;
            for (let index in onlineReference.organic) {
                let reference = onlineReference.organic[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }

    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
        numOnlineReferences += onlineContext.knowledgeGraph.length;
        for (let index in onlineContext.knowledgeGraph) {
            let reference = onlineContext.knowledgeGraph[index];
            let polishedReference = generateOnlineReference(reference, index);
            referenceSection.appendChild(polishedReference);
        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
            numOnlineReferences += onlineReference.knowledgeGraph.length;
            for (let index in onlineReference.knowledgeGraph) {
                let reference = onlineReference.knowledgeGraph[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }

    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
        for (let index in onlineContext.peopleAlsoAsk) {
            let reference = onlineContext.peopleAlsoAsk[index];
            let polishedReference = generateOnlineReference(reference, index);
            referenceSection.appendChild(polishedReference);
        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
            for (let index in onlineReference.peopleAlsoAsk) {
                let reference = onlineReference.peopleAlsoAsk[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }
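The client-side change mirrors the new server payload: onlineContext is no longer a single result object but a map from subquery string to result object, so the renderer loops over subqueries before walking each one's organic, knowledgeGraph, and peopleAlsoAsk sections. A sketch of the shape the function now expects, written as a Python dict with illustrative placeholder values:

    # Hypothetical onlineContext payload after this change (values are placeholders)
    online_context = {
        "weather in new york": {
            "organic": [{"title": "...", "link": "..."}],
            "knowledgeGraph": {},
            "answerBox": [],
            "peopleAlsoAsk": [],
        },
        "weather in san francisco": {
            "organic": [],
            "knowledgeGraph": {},
            "answerBox": [],
            "peopleAlsoAsk": [],
        },
    }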
@@ -316,15 +319,28 @@
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`;
referenceExpandButton.innerHTML = expandButtonText;

references.appendChild(referenceExpandButton);

let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");

let numReferences = 0;

// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
    numReferences = rawReferenceAsJson.length;

    rawReferenceAsJson.forEach((reference, index) => {
        let polishedReference = generateReference(reference, index);
        referenceSection.appendChild(polishedReference);
    });
} else {
    numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}

references.appendChild(referenceExpandButton);

referenceExpandButton.addEventListener('click', function() {
    if (referenceSection.classList.contains("collapsed")) {
        referenceSection.classList.remove("collapsed");
@@ -335,10 +351,8 @@
    }
});

rawReferenceAsJson.forEach((reference, index) => {
    let polishedReference = generateReference(reference, index);
    referenceSection.appendChild(polishedReference);
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
@@ -419,7 +433,7 @@
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };

fetch(`${hostURL}/api/chat/history?client=web`, { headers })
fetch(`${hostURL}/api/chat/history?client=desktop`, { headers })
    .then(response => response.json())
    .then(data => {
        if (data.detail) {
@@ -147,35 +147,38 @@ To get started, just start typing below. You can also type / to see a list of co
function processOnlineReferences(referenceSection, onlineContext) {
    let numOnlineReferences = 0;
    if (onlineContext.organic && onlineContext.organic.length > 0) {
        numOnlineReferences += onlineContext.organic.length;
        for (let index in onlineContext.organic) {
            let reference = onlineContext.organic[index];
    for (let subquery in onlineContext) {
        let onlineReference = onlineContext[subquery];
        if (onlineReference.organic && onlineReference.organic.length > 0) {
            numOnlineReferences += onlineReference.organic.length;
            for (let index in onlineReference.organic) {
                let reference = onlineReference.organic[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }

    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
        numOnlineReferences += onlineContext.knowledgeGraph.length;
        for (let index in onlineContext.knowledgeGraph) {
            let reference = onlineContext.knowledgeGraph[index];
        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
            numOnlineReferences += onlineReference.knowledgeGraph.length;
            for (let index in onlineReference.knowledgeGraph) {
                let reference = onlineReference.knowledgeGraph[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }

    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
        for (let index in onlineContext.peopleAlsoAsk) {
            let reference = onlineContext.peopleAlsoAsk[index];
            let polishedReference = generateOnlineReference(reference, index);
            referenceSection.appendChild(polishedReference);
        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
            for (let index in onlineReference.peopleAlsoAsk) {
                let reference = onlineReference.peopleAlsoAsk[index];
                let polishedReference = generateOnlineReference(reference, index);
                referenceSection.appendChild(polishedReference);
            }
        }
    }
    }

    return numOnlineReferences;
}
}

function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
    if (context == null && onlineContext == null) {
@@ -356,7 +359,6 @@ To get started, just start typing below. You can also type / to see a list of co
references = document.createElement('div');
references.classList.add("references");


let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
@@ -179,13 +179,6 @@ def converse_offline(


def llm_thread(g, messages: List[ChatMessage], model: Any):
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
        raise e

    assert isinstance(model, GPT4All), "model should be of type GPT4All"
    user_message = messages[-1]
    system_message = messages[0]
    conversation_history = messages[1:-1]
@@ -204,7 +197,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
    prompted_message = templated_system_message + chat_history + templated_user_message

    state.chat_lock.acquire()
    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
    try:
        for response in response_iterator:
            if any(stop_word in response.strip() for stop_word in stop_words):
@@ -214,3 +207,18 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
    finally:
        state.chat_lock.release()
    g.close()


def send_message_to_model_offline(
    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
):
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
        raise e

    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
    gpt4all_model = loaded_model or GPT4All(model)

    return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)
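The new send_message_to_model_offline helper lets callers query the offline model directly, reusing an already-loaded GPT4All instance when one is available. A minimal usage sketch; loading the model locally like this is an assumption for illustration, the model file name is the default from the signature above:

    from gpt4all import GPT4All

    # Reuse a preloaded model to avoid paying the load cost on every call
    loaded_model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf")
    response = send_message_to_model_offline(
        "Provide search queries as a JSON list: What is the weather like in New York?",
        loaded_model=loaded_model,
        streaming=False,  # return the full completion as a string
    )
    print(response)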
@@ -100,6 +100,27 @@ def extract_questions(
    return questions


def send_message_to_model(
    message,
    api_key,
    model,
):
    """
    Send message to model
    """
    messages = [ChatMessage(content=message, role="assistant")]

    # Get Response from GPT
    return completion_with_backoff(
        messages=messages,
        model_name=model,
        temperature=0,
        max_tokens=100,
        model_kwargs={"stop": ["A: ", "\n"]},
        openai_api_key=api_key,
    )


def converse(
    references,
    online_results,
@@ -121,6 +121,32 @@ Information from the internet: {online_results}
Query: {query}""".strip()
)

online_search_conversation_subqueries = PromptTemplate.from_template(
    """
The user has a question which you can use the internet to respond to. Can you break down the question into subqueries to get the correct answer? Provide search queries as a JSON list of strings

Today's date in UTC: {current_date}

Here are some examples of questions and subqueries:
Q: What is the weather like in New York?
A: ["weather in new york"]

Q: What is the weather like in New York and San Francisco?
A: ["weather in new york", "weather in san francisco"]

Q: What is the latest news about Google stock?
A: ["google stock news"]

Q: When is the next lunar eclipse?
A: ["next lunar eclipse"]

Q: How many oranges would fit in NASA's Saturn V rocket?
A: ["volume of an orange", "volume of saturn v rocket"]

This is the user's query:
Q: {query}
A: """.strip()
)

## Summarize Notes
## --
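The few-shot prompt constrains the model to emit a bare JSON list of strings, so the server can parse the completion directly with json.loads (the full validation lives in generate_online_subqueries further down). A minimal sketch of the round trip, with the model response hard-coded as an assumption:

    import json

    # A well-behaved completion for "How many oranges would fit in NASA's Saturn V rocket?"
    model_response = '["volume of an orange", "volume of saturn v rocket"]'
    subqueries = json.loads(model_response)                      # -> list of strings
    subqueries = [q.strip() for q in subqueries if q.strip()]    # drop empty entries
    assert isinstance(subqueries, list) and subqueries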
@@ -3,6 +3,8 @@ import json
import os
import logging

from khoj.routers.helpers import generate_online_subqueries

logger = logging.getLogger(__name__)

SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
@@ -10,29 +12,41 @@ SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
url = "https://google.serper.dev/search"


def search_with_google(query: str):
async def search_with_google(query: str):
    def _search_with_google(subquery: str):
        payload = json.dumps(
            {
                "q": subquery,
            }
        )

        headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}

        response = requests.request("POST", url, headers=headers, data=payload)

        if response.status_code != 200:
            logger.error(response.text)
            return {}

        json_response = response.json()
        sub_response_dict = {}
        sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
        sub_response_dict["organic"] = json_response.get("organic", [])
        sub_response_dict["answerBox"] = json_response.get("answerBox", [])
        sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])

        return sub_response_dict

    if SERPER_DEV_API_KEY is None:
        raise ValueError("SERPER_DEV_API_KEY is not set")

    payload = json.dumps(
        {
            "q": query,
        }
    )

    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}

    response = requests.request("POST", url, headers=headers, data=payload)

    if response.status_code != 200:
        logger.error(response.text)
        return {}

    json_response = response.json()
    # Breakdown the query into subqueries to get the correct answer
    subqueries = await generate_online_subqueries(query)

    response_dict = {}
    response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
    response_dict["organic"] = json_response.get("organic", [])
    response_dict["answerBox"] = json_response.get("answerBox", [])
    response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])

    for subquery in subqueries:
        logger.info(f"Searching with Google for '{subquery}'")
        response_dict[subquery] = _search_with_google(subquery)

    return response_dict
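Since search_with_google is now a coroutine, callers must await it; the chat API hunk below makes exactly this change. A minimal sketch of calling it standalone, assuming SERPER_DEV_API_KEY is set in the environment:

    import asyncio

    async def demo():
        # Returns {"<subquery>": {"organic": [...], "knowledgeGraph": {...}, ...}, ...}
        online_results = await search_with_google("What is the weather like in New York and San Francisco?")
        for subquery, result in online_results.items():
            print(subquery, len(result.get("organic", [])))

    asyncio.run(demo())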
@@ -597,7 +597,7 @@ async def chat(

    elif conversation_command == ConversationCommand.Online:
        try:
            online_results = search_with_google(defiltered_query)
            online_results = await search_with_google(defiltered_query)
        except ValueError as e:
            return StreamingResponse(
                iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
@@ -6,8 +6,13 @@ from datetime import datetime
from functools import partial
import logging
from time import time
import json
from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any

from datetime import datetime

from khoj.processor.conversation import prompts

# External Packages
from fastapi import HTTPException, Header, Request, Depends
@@ -15,10 +20,10 @@ from fastapi import HTTPException, Header, Request, Depends
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser, Subscription
from database.models import KhojUser, Subscription, ChatModelOptions
from database.adapters import ConversationAdapters
@@ -114,6 +119,65 @@ async def agenerate_chat_response(*args):
    return await loop.run_in_executor(executor, generate_chat_response, *args)


async def generate_online_subqueries(q: str) -> List[str]:
    """
    Generate subqueries from the given query
    """
    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
        current_date=utc_date,
        query=q,
    )

    response = await send_message_to_model_wrapper(online_queries_prompt)

    # Validate that the response is a non-empty, JSON-serializable list
    try:
        response = response.strip()
        response = json.loads(response)
        response = [q.strip() for q in response if q.strip()]
        if not isinstance(response, list) or not response or len(response) == 0:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [q]
        return response
    except Exception as e:
        logger.error(f"Invalid response for constructing subqueries: {response}")
        return [q]


async def send_message_to_model_wrapper(
    message: str,
):
    conversation_config = await ConversationAdapters.aget_default_conversation_config()

    if conversation_config is None:
        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

    if conversation_config.model_type == "offline":
        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)

        loaded_model = state.gpt4all_processor_config.loaded_model
        return send_message_to_model_offline(
            message=message,
            loaded_model=loaded_model,
            model=conversation_config.chat_model,
            streaming=False,
        )

    elif conversation_config.model_type == "openai":
        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
        api_key = openai_chat_config.api_key
        chat_model = conversation_config.chat_model
        return send_message_to_model(
            message=message,
            api_key=api_key,
            model=chat_model,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")


def generate_chat_response(
    q: str,
    meta_log: dict,
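Together these two helpers form the wrapper described in the commit message: generate_online_subqueries builds the few-shot prompt, sends it through send_message_to_model_wrapper (which picks the offline or OpenAI backend from the server's default chat config), and falls back to the original query whenever the response is not a usable JSON list. A minimal sketch of calling it outside the request path, assuming a configured server:

    import asyncio

    async def demo():
        subqueries = await generate_online_subqueries("What is the weather like in New York and San Francisco?")
        # e.g. ["weather in new york", "weather in san francisco"]; falls back to [query] on parse errors
        print(subqueries)

    asyncio.run(demo())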
@@ -163,12 +227,8 @@ def generate_chat_response(
        meta_log=meta_log,
    )

    offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
    conversation_config = ConversationAdapters.get_conversation_config(user)
    if conversation_config is None:
        conversation_config = ConversationAdapters.get_default_conversation_config()
    openai_chat_config = ConversationAdapters.get_openai_conversation_config()
    if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
    conversation_config = ConversationAdapters.get_valid_conversation_config(user)
    if conversation_config.model_type == "offline":
        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
@@ -186,7 +246,8 @@ def generate_chat_response(
            tokenizer_name=conversation_config.tokenizer,
        )

    elif openai_chat_config and conversation_config.model_type == "openai":
    elif conversation_config.model_type == "openai":
        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
        api_key = openai_chat_config.api_key
        chat_model = conversation_config.chat_model
        chat_response = converse(