Add subqueries for internet-connected search results and update client-side code accordingly

- Add a wrapper method for making direct queries to the LLM and determining any intermediate responses needed to handle the request
sabaimran 2023-11-20 15:19:15 -08:00
parent b8e6883a81
commit fee99779bf
9 changed files with 253 additions and 86 deletions
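
At a high level, this commit changes the shape of online search results from one flat Serper response into a dictionary keyed by LLM-generated subquery, which is why the server and the web/desktop clients change together. A minimal, self-contained sketch of that shape change (illustrative data only, not actual Khoj payloads):

    # Before: search_with_google returned a single flat response.
    old_results = {"organic": [{"title": "NYC weather"}], "knowledgeGraph": {}, "peopleAlsoAsk": []}

    # After: results are keyed by each generated subquery.
    new_results = {
        "weather in new york": {"organic": [{"title": "NYC weather"}], "knowledgeGraph": {}, "peopleAlsoAsk": []},
        "weather in san francisco": {"organic": [{"title": "SF weather"}], "knowledgeGraph": {}, "peopleAlsoAsk": []},
    }

    # Clients therefore gain one outer loop over subqueries before reading the
    # organic/knowledgeGraph/peopleAlsoAsk sections, mirroring the
    # processOnlineReferences changes in the diffs below.
    for subquery, result in new_results.items():
        for reference in result.get("organic", []):
            print(subquery, reference["title"])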

View file

@@ -4,6 +4,7 @@ from datetime import date, datetime
 import secrets
 from typing import Type, List
 from datetime import date, timezone
+import random
 from django.db import models
 from django.contrib.sessions.backends.db import SessionStore
@@ -339,6 +340,26 @@ class ConversationAdapters:
     async def get_openai_chat_config():
         return await OpenAIProcessorConversationConfig.objects.filter().afirst()

+    @staticmethod
+    def get_valid_conversation_config(user: KhojUser):
+        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
+        conversation_config = ConversationAdapters.get_conversation_config(user)
+        if conversation_config is None:
+            conversation_config = ConversationAdapters.get_default_conversation_config()
+
+        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+            if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+            return conversation_config
+
+        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
+        if openai_chat_config and conversation_config.model_type == "openai":
+            return conversation_config
+        else:
+            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
+

 class EntryAdapters:
     word_filer = WordFilter()

View file

@@ -137,30 +137,33 @@
 function processOnlineReferences(referenceSection, onlineContext) {
     let numOnlineReferences = 0;

-    if (onlineContext.organic && onlineContext.organic.length > 0) {
-        numOnlineReferences += onlineContext.organic.length;
-        for (let index in onlineContext.organic) {
-            let reference = onlineContext.organic[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+    for (let subquery in onlineContext) {
+        let onlineReference = onlineContext[subquery];
+        if (onlineReference.organic && onlineReference.organic.length > 0) {
+            numOnlineReferences += onlineReference.organic.length;
+            for (let index in onlineReference.organic) {
+                let reference = onlineReference.organic[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-        numOnlineReferences += onlineContext.knowledgeGraph.length;
-        for (let index in onlineContext.knowledgeGraph) {
-            let reference = onlineContext.knowledgeGraph[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+            numOnlineReferences += onlineReference.knowledgeGraph.length;
+            for (let index in onlineReference.knowledgeGraph) {
+                let reference = onlineReference.knowledgeGraph[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-        for (let index in onlineContext.peopleAlsoAsk) {
-            let reference = onlineContext.peopleAlsoAsk[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+            for (let index in onlineReference.peopleAlsoAsk) {
+                let reference = onlineReference.peopleAlsoAsk[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }
+    }
@@ -316,15 +319,28 @@
     let referenceExpandButton = document.createElement('button');
     referenceExpandButton.classList.add("reference-expand-button");
-    let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`;
-    referenceExpandButton.innerHTML = expandButtonText;
-    references.appendChild(referenceExpandButton);

     let referenceSection = document.createElement('div');
     referenceSection.classList.add("reference-section");
     referenceSection.classList.add("collapsed");

+    let numReferences = 0;
+
+    // If rawReferenceAsJson is a list, then count the length
+    if (Array.isArray(rawReferenceAsJson)) {
+        numReferences = rawReferenceAsJson.length;
+
+        rawReferenceAsJson.forEach((reference, index) => {
+            let polishedReference = generateReference(reference, index);
+            referenceSection.appendChild(polishedReference);
+        });
+    } else {
+        numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+    }
+
+    references.appendChild(referenceExpandButton);

     referenceExpandButton.addEventListener('click', function() {
         if (referenceSection.classList.contains("collapsed")) {
             referenceSection.classList.remove("collapsed");
@@ -335,10 +351,8 @@
         }
     });

-    rawReferenceAsJson.forEach((reference, index) => {
-        let polishedReference = generateReference(reference, index);
-        referenceSection.appendChild(polishedReference);
-    });
+    let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+    referenceExpandButton.innerHTML = expandButtonText;

     references.appendChild(referenceSection);
     readStream();
 } else {
@@ -419,7 +433,7 @@
     const khojToken = await window.tokenAPI.getToken();
     const headers = { 'Authorization': `Bearer ${khojToken}` };

-    fetch(`${hostURL}/api/chat/history?client=web`, { headers })
+    fetch(`${hostURL}/api/chat/history?client=desktop`, { headers })
         .then(response => response.json())
         .then(data => {
             if (data.detail) {

View file

@@ -147,35 +147,38 @@ To get started, just start typing below. You can also type / to see a list of commands.
 function processOnlineReferences(referenceSection, onlineContext) {
     let numOnlineReferences = 0;

-    if (onlineContext.organic && onlineContext.organic.length > 0) {
-        numOnlineReferences += onlineContext.organic.length;
-        for (let index in onlineContext.organic) {
-            let reference = onlineContext.organic[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+    for (let subquery in onlineContext) {
+        let onlineReference = onlineContext[subquery];
+        if (onlineReference.organic && onlineReference.organic.length > 0) {
+            numOnlineReferences += onlineReference.organic.length;
+            for (let index in onlineReference.organic) {
+                let reference = onlineReference.organic[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-        numOnlineReferences += onlineContext.knowledgeGraph.length;
-        for (let index in onlineContext.knowledgeGraph) {
-            let reference = onlineContext.knowledgeGraph[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+            numOnlineReferences += onlineReference.knowledgeGraph.length;
+            for (let index in onlineReference.knowledgeGraph) {
+                let reference = onlineReference.knowledgeGraph[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }

-    if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-        numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-        for (let index in onlineContext.peopleAlsoAsk) {
-            let reference = onlineContext.peopleAlsoAsk[index];
-            let polishedReference = generateOnlineReference(reference, index);
-            referenceSection.appendChild(polishedReference);
-        }
-    }
+        if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+            numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+            for (let index in onlineReference.peopleAlsoAsk) {
+                let reference = onlineReference.peopleAlsoAsk[index];
+                let polishedReference = generateOnlineReference(reference, index);
+                referenceSection.appendChild(polishedReference);
+            }
+        }
+    }

     return numOnlineReferences;
 }

 function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
     if (context == null && onlineContext == null) {
@@ -356,7 +359,6 @@ To get started, just start typing below. You can also type / to see a list of commands.
     references = document.createElement('div');
     references.classList.add("references");

     let referenceExpandButton = document.createElement('button');
     referenceExpandButton.classList.add("reference-expand-button");

View file

@@ -179,13 +179,6 @@ def converse_offline(

 def llm_thread(g, messages: List[ChatMessage], model: Any):
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert isinstance(model, GPT4All), "model should be of type GPT4All"
     user_message = messages[-1]
     system_message = messages[0]
     conversation_history = messages[1:-1]
@@ -204,7 +197,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     prompted_message = templated_system_message + chat_history + templated_user_message

     state.chat_lock.acquire()
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
+    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
     try:
         for response in response_iterator:
             if any(stop_word in response.strip() for stop_word in stop_words):
@@ -214,3 +207,18 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     finally:
         state.chat_lock.release()
     g.close()
+
+
+def send_message_to_model_offline(
+    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
+):
+    try:
+        from gpt4all import GPT4All
+    except ModuleNotFoundError as e:
+        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
+        raise e
+
+    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
+    gpt4all_model = loaded_model or GPT4All(model)
+
+    return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)
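
For orientation, a hedged usage sketch of the new offline helper (assumes the gpt4all package is installed and the model file can be downloaded on first use; the prompt strings are illustrative):

    # One-off call: loads the default mistral-7b-instruct GGUF model on demand.
    answer = send_message_to_model_offline("What is 2 + 2?")
    print(answer)

    # Reusing an already-loaded model avoids reloading weights per call,
    # which is how send_message_to_model_wrapper invokes it below.
    from gpt4all import GPT4All

    model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf")
    for token in send_message_to_model_offline("Tell me a joke", loaded_model=model, streaming=True):
        print(token, end="")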

View file

@@ -100,6 +100,27 @@ def extract_questions(
     return questions


+def send_message_to_model(
+    message,
+    api_key,
+    model,
+):
+    """
+    Send message to model
+    """
+    messages = [ChatMessage(content=message, role="assistant")]
+
+    # Get Response from GPT
+    return completion_with_backoff(
+        messages=messages,
+        model_name=model,
+        temperature=0,
+        max_tokens=100,
+        model_kwargs={"stop": ["A: ", "\n"]},
+        openai_api_key=api_key,
+    )
+
+
 def converse(
     references,
     online_results,
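
A hedged usage sketch of this OpenAI-side helper (the API key and model name below are placeholders, not values from the commit):

    prompt = 'Q: What is the capital of France?\nA: '
    response = send_message_to_model(
        message=prompt,
        api_key="sk-...",  # placeholder OpenAI API key
        model="gpt-3.5-turbo",  # placeholder model name
    )

Note the stop sequences ["A: ", "\n"] cut the completion at the first newline, which suits the single-line JSON list that the subqueries prompt below expects.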

View file

@@ -121,6 +121,32 @@ Information from the internet: {online_results}

 Query: {query}""".strip()
 )

+online_search_conversation_subqueries = PromptTemplate.from_template(
+    """
+The user has a question which you can use the internet to respond to. Can you break down the question into subqueries to get the correct answer? Provide search queries as a JSON list of strings
+Today's date in UTC: {current_date}
+
+Here are some examples of questions and subqueries:
+Q: What is the weather like in New York?
+A: ["weather in new york"]
+
+Q: What is the weather like in New York and San Francisco?
+A: ["weather in new york", "weather in san francisco"]
+
+Q: What is the latest news about Google stock?
+A: ["google stock news"]
+
+Q: When is the next lunar eclipse?
+A: ["next lunar eclipse"]
+
+Q: How many oranges would fit in NASA's Saturn V rocket?
+A: ["volume of an orange", "volume of saturn v rocket"]
+
+This is the user's query:
+Q: {query}
+A: """.strip()
+)

 ## Summarize Notes
 ## --
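
For context, a hedged sketch of how the template above gets used (PromptTemplate here follows langchain's API; the parsing mirrors generate_online_subqueries in the helpers change below, and the completion string is illustrative):

    import json
    from datetime import datetime

    prompt = online_search_conversation_subqueries.format(
        current_date=datetime.utcnow().strftime("%Y-%m-%d"),
        query="What is the weather like in New York and San Francisco?",
    )
    # The model is expected to complete the trailing "A: " with a JSON list,
    # which json.loads turns into the subquery strings fed to the search layer.
    raw_completion = '["weather in new york", "weather in san francisco"]'
    subqueries = json.loads(raw_completion)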

View file

@@ -3,6 +3,8 @@ import json
 import os
 import logging

+from khoj.routers.helpers import generate_online_subqueries
+
 logger = logging.getLogger(__name__)

 SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
@@ -10,29 +12,41 @@ SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")

 url = "https://google.serper.dev/search"


-def search_with_google(query: str):
+async def search_with_google(query: str):
+    def _search_with_google(subquery: str):
+        payload = json.dumps(
+            {
+                "q": subquery,
+            }
+        )
+        headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+
+        if response.status_code != 200:
+            logger.error(response.text)
+            return {}
+
+        json_response = response.json()
+        sub_response_dict = {}
+        sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
+        sub_response_dict["organic"] = json_response.get("organic", [])
+        sub_response_dict["answerBox"] = json_response.get("answerBox", [])
+        sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+
+        return sub_response_dict
+
     if SERPER_DEV_API_KEY is None:
         raise ValueError("SERPER_DEV_API_KEY is not set")

-    payload = json.dumps(
-        {
-            "q": query,
-        }
-    )
-    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
-
-    response = requests.request("POST", url, headers=headers, data=payload)
-
-    if response.status_code != 200:
-        logger.error(response.text)
-        return {}
-
-    json_response = response.json()
+    # Break down the query into subqueries to get the correct answer
+    subqueries = await generate_online_subqueries(query)
+
     response_dict = {}
-    response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
-    response_dict["organic"] = json_response.get("organic", [])
-    response_dict["answerBox"] = json_response.get("answerBox", [])
-    response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+
+    for subquery in subqueries:
+        logger.info(f"Searching with Google for '{subquery}'")
+        response_dict[subquery] = _search_with_google(subquery)

     return response_dict

View file

@@ -597,7 +597,7 @@ async def chat(
     elif conversation_command == ConversationCommand.Online:
         try:
-            online_results = search_with_google(defiltered_query)
+            online_results = await search_with_google(defiltered_query)
         except ValueError as e:
             return StreamingResponse(
                 iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),

View file

@@ -6,8 +6,13 @@ from datetime import datetime
 from functools import partial
 import logging
 from time import time
+import json
 from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any
+from datetime import datetime
+
+from khoj.processor.conversation import prompts

 # External Packages
 from fastapi import HTTPException, Header, Request, Depends
@@ -15,10 +20,10 @@ from fastapi import HTTPException, Header, Request, Depends
 from khoj.utils import state
 from khoj.utils.config import GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, log_telemetry
-from khoj.processor.conversation.openai.gpt import converse
-from khoj.processor.conversation.gpt4all.chat_model import converse_offline
+from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
+from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
-from database.models import KhojUser, Subscription
+from database.models import KhojUser, Subscription, ChatModelOptions
 from database.adapters import ConversationAdapters
@@ -114,6 +119,65 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)


+async def generate_online_subqueries(q: str) -> List[str]:
+    """
+    Generate subqueries from the given query
+    """
+    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
+        current_date=utc_date,
+        query=q,
+    )
+
+    response = await send_message_to_model_wrapper(online_queries_prompt)
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    try:
+        response = json.loads(response.strip())
+        if not isinstance(response, list) or not response:
+            logger.error(f"Invalid response for constructing subqueries: {response}")
+            return [q]
+        # Use a distinct loop variable so the fallback `[q]` still refers to the original query
+        return [subquery.strip() for subquery in response if subquery.strip()]
+    except Exception:
+        logger.error(f"Invalid response for constructing subqueries: {response}")
+        return [q]
+
+
+async def send_message_to_model_wrapper(
+    message: str,
+):
+    conversation_config = await ConversationAdapters.aget_default_conversation_config()
+
+    if conversation_config is None:
+        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
+
+    if conversation_config.model_type == "offline":
+        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+
+        loaded_model = state.gpt4all_processor_config.loaded_model
+        return send_message_to_model_offline(
+            message=message,
+            loaded_model=loaded_model,
+            model=conversation_config.chat_model,
+            streaming=False,
+        )
+    elif conversation_config.model_type == "openai":
+        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        api_key = openai_chat_config.api_key
+        chat_model = conversation_config.chat_model
+        return send_message_to_model(
+            message=message,
+            api_key=api_key,
+            model=chat_model,
+        )
+    else:
+        raise HTTPException(status_code=500, detail="Invalid conversation config")
+
+
 def generate_chat_response(
     q: str,
     meta_log: dict,
@@ -163,12 +227,8 @@ def generate_chat_response(
         meta_log=meta_log,
     )

-    offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
-    conversation_config = ConversationAdapters.get_conversation_config(user)
-    if conversation_config is None:
-        conversation_config = ConversationAdapters.get_default_conversation_config()
-    openai_chat_config = ConversationAdapters.get_openai_conversation_config()
-    if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+    conversation_config = ConversationAdapters.get_valid_conversation_config(user)
+
+    if conversation_config.model_type == "offline":
         if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
             state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
@@ -186,7 +246,8 @@
             tokenizer_name=conversation_config.tokenizer,
         )

-    elif openai_chat_config and conversation_config.model_type == "openai":
+    elif conversation_config.model_type == "openai":
+        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
         api_key = openai_chat_config.api_key
         chat_model = conversation_config.chat_model
         chat_response = converse(