Add support for using serper.dev for online queries

- Use the knowledgeGraph, answerBox, peopleAlsoAsk and organic responses of serper.dev to provide online context for queries made with the /online command
- Add it as an additional tool for doing Google searches
- Render the results appropriately in the chat web window
- Pass appropriate reference data down to the LLM
This commit is contained in:
sabaimran 2023-11-17 16:19:11 -08:00
parent bfbe273ffd
commit 0fcf234f07
11 changed files with 250 additions and 21 deletions

View file

@@ -70,6 +70,52 @@ To get started, just start typing below. You can also type / to see a list of co
return referenceButton;
}
function generateOnlineReference(reference, index) {
    // Render a single online search result (from serper.dev) as a collapsible
    // reference button: collapsed shows just the link, expanded adds the
    // optional "People also ask" question and the result snippet.
    let title = reference.title;
    let link = reference.link;
    let snippet = reference.snippet;
    let question = reference.question;

    // Escape untrusted remote content before inserting it via innerHTML to
    // prevent XSS from malicious search-result titles/questions/snippets.
    const escapeHtml = (text) => {
        const div = document.createElement('div');
        div.textContent = text ?? "";
        return div.innerHTML;
    };

    // Prefix the expanded view with the question, when present.
    const questionHtml = question ? `<b>Question:</b> ${escapeHtml(question)}<br><br>` : "";

    let linkElement = document.createElement('a');
    linkElement.setAttribute('href', link);
    linkElement.setAttribute('target', '_blank');
    linkElement.setAttribute('rel', 'noopener noreferrer');
    linkElement.classList.add("inline-chat-link");
    linkElement.classList.add("reference-link");
    linkElement.setAttribute('title', title);
    // textContent (not innerHTML): the title comes from an external API.
    linkElement.textContent = title;

    let referenceButton = document.createElement('button');
    referenceButton.innerHTML = linkElement.outerHTML;
    referenceButton.id = `ref-${index}`;
    referenceButton.classList.add("reference-button");
    referenceButton.classList.add("collapsed");
    referenceButton.tabIndex = 0;

    // Toggle between collapsed (link only) and expanded (link + question + snippet).
    referenceButton.addEventListener('click', function() {
        console.log(`Toggling ref-${index}`);
        if (this.classList.contains("collapsed")) {
            this.classList.remove("collapsed");
            this.classList.add("expanded");
            this.innerHTML = linkElement.outerHTML + `<br><br>${questionHtml}${escapeHtml(snippet)}`;
        } else {
            this.classList.add("collapsed");
            this.classList.remove("expanded");
            this.innerHTML = linkElement.outerHTML;
        }
    });
    return referenceButton;
}
function renderMessage(message, by, dt=null, annotations=null) {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
@ -99,8 +145,45 @@ To get started, just start typing below. You can also type / to see a list of co
chatBody.scrollTop = chatBody.scrollHeight;
}
function renderMessageWithReference(message, by, context=null, dt=null) {
if (context == null || context.length == 0) {
function processOnlineReferences(referenceSection, onlineContext) {
    // Append serper.dev search results (organic, knowledgeGraph, peopleAlsoAsk)
    // to the given reference section and return how many references were added.
    let numOnlineReferences = 0;

    // Each section is expected to be an array of {title, link, snippet, question?}.
    // NOTE(review): the backend defaults knowledgeGraph to a dict, not a list —
    // confirm its shape with the serper.dev response; non-array sections are
    // skipped here (matching the original `.length` guard, which was never true
    // for plain objects).
    const appendReferences = (references) => {
        if (!Array.isArray(references) || references.length === 0) return;
        numOnlineReferences += references.length;
        references.forEach((reference, index) => {
            referenceSection.appendChild(generateOnlineReference(reference, index));
        });
    };

    appendReferences(onlineContext.organic);
    appendReferences(onlineContext.knowledgeGraph);
    appendReferences(onlineContext.peopleAlsoAsk);
    return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && onlineContext.organic.length == 0 && onlineContext.knowledgeGraph.length == 0 && onlineContext.peopleAlsoAsk.length == 0))) {
renderMessage(message, by, dt);
return;
}
@ -109,8 +192,11 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let expandButtonText = context.length == 1 ? "1 reference" : `${context.length} references`;
referenceExpandButton.innerHTML = expandButtonText;
let numReferences = 0;
if (context) {
numReferences += context.length;
}
references.appendChild(referenceExpandButton);
@ -136,6 +222,14 @@ To get started, just start typing below. You can also type / to see a list of co
referenceSection.appendChild(polishedReference);
}
}
if (onlineContext) {
numReferences += processOnlineReferences(referenceSection, onlineContext);
}
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
renderMessage(message, by, dt, references);
@ -177,6 +271,13 @@ To get started, just start typing below. You can also type / to see a list of co
// Add the class "chat-response" to each element
codeElement.classList.add("chat-response");
});
let anchorElements = element.querySelectorAll('a');
anchorElements.forEach((anchorElement) => {
// Add the class "inline-chat-link" to each element
anchorElement.classList.add("inline-chat-link");
});
return element
}
@ -258,15 +359,28 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceExpandButton);
let referenceSection = document.createElement('div');
referenceSection.classList.add("reference-section");
referenceSection.classList.add("collapsed");
let numReferences = 0;
// If rawReferenceAsJson is a list, then count the length
if (Array.isArray(rawReferenceAsJson)) {
numReferences = rawReferenceAsJson.length;
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
} else {
numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
references.appendChild(referenceExpandButton);
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
referenceSection.classList.remove("collapsed");
@ -277,10 +391,8 @@ To get started, just start typing below. You can also type / to see a list of co
}
});
rawReferenceAsJson.forEach((reference, index) => {
let polishedReference = generateReference(reference, index);
referenceSection.appendChild(polishedReference);
});
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
readStream();
} else {
@ -371,7 +483,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created));
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
});
})
.catch(err => {
@ -665,6 +777,11 @@ To get started, just start typing below. You can also type / to see a list of co
border-bottom: 1px dotted var(--main-text-color);
}
a.reference-link {
color: var(--main-text-color);
border-bottom: 1px dotted var(--main-text-color);
}
button.copy-button {
display: block;
border-radius: 4px;
@ -749,6 +866,10 @@ To get started, just start typing below. You can also type / to see a list of co
padding: 0px;
}
p {
margin: 0;
}
div.programmatic-output {
background-color: #f5f5f5;
border: 1px solid #ddd;

View file

@@ -121,6 +121,7 @@ def filter_questions(questions: List[str]):
def converse_offline(
references,
online_results,
user_query,
conversation_log={},
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
@ -147,6 +148,13 @@ def converse_offline(
# Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
@ -164,7 +172,7 @@ def converse_offline(
tokenizer_name=tokenizer_name,
)
g = ThreadedGenerator(references, completion_func=completion_func)
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
t.start()
return g

View file

@ -88,6 +88,7 @@ def extract_questions(
def converse(
references,
online_results,
user_query,
conversation_log={},
model: str = "gpt-3.5-turbo",
@ -109,6 +110,13 @@ def converse(
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
conversation_primer = prompts.online_search_conversation.format(
query=user_query, online_results=str(online_results)
)
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
else:
@ -130,6 +138,7 @@ def converse(
return chat_completion_with_backoff(
messages=messages,
compiled_references=references,
online_results=online_results,
model_name=model,
temperature=temperature,
openai_api_key=api_key,

View file

@ -69,9 +69,16 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
messages,
compiled_references,
online_results,
model_name,
temperature,
openai_api_key=None,
completion_func=None,
model_kwargs=None,
):
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
t.start()
return g

View file

@ -10,7 +10,7 @@ You are Khoj, a smart, inquisitive and helpful personal assistant.
Use your general knowledge and the past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. They can share files with you using the Khoj desktop application.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
@ -35,6 +35,12 @@ no_notes_found = PromptTemplate.from_template(
""".strip()
)
## No Online Results Found
## --
# Shown to the user when the /online command returns no usable search results.
no_online_results_found = PromptTemplate.from_template(
"""
I'm sorry, I couldn't find any relevant information from the internet to respond to your message.
""".strip()
)
no_entries_found = PromptTemplate.from_template(
"""
It looks like you haven't added any notes yet. No worries, you can fix that by downloading the Khoj app from <a href=https://khoj.dev/downloads>here</a>.
@ -103,6 +109,18 @@ Question: {query}
""".strip()
)
## Online Search Conversation
## --
# Primes the LLM with serper.dev search results for the /online command.
# Expects two template variables: {online_results} (stringified search
# response) and {query} (the user's defiltered question).
online_search_conversation = PromptTemplate.from_template(
"""
Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet: {online_results}
Query: {query}""".strip()
)
## Summarize Notes
## --

View file

@ -29,9 +29,10 @@ model_to_tokenizer = {
class ThreadedGenerator:
def __init__(self, compiled_references, completion_func=None):
def __init__(self, compiled_references, online_results, completion_func=None):
self.queue = queue.Queue()
self.compiled_references = compiled_references
self.online_results = online_results
self.completion_func = completion_func
self.response = ""
self.start_time = perf_counter()
@ -62,6 +63,8 @@ class ThreadedGenerator:
def close(self):
if self.compiled_references and len(self.compiled_references) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}")
elif self.online_results and len(self.online_results) > 0:
self.queue.put(f"### compiled references:{json.dumps(self.online_results)}")
self.queue.put(StopIteration)

View file

View file

@@ -0,0 +1,38 @@
import requests
import json
import os
import logging
logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
url = "https://google.serper.dev/search"
def search_with_google(query: str):
    """Search Google for `query` via the serper.dev API.

    Returns a dict with the `knowledgeGraph`, `organic`, `answerBox` and
    `peopleAlsoAsk` sections of the serper.dev response, or an empty dict
    on any non-200 response.

    Raises:
        ValueError: if the SERPER_DEV_API_KEY environment variable is unset.
    """
    if SERPER_DEV_API_KEY is None:
        raise ValueError("SERPER_DEV_API_KEY is not set")

    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}

    # Bound the request so a slow upstream cannot hang the chat request.
    # `json=` serializes the payload and sets Content-Type for us.
    response = requests.post(url, headers=headers, json={"q": query}, timeout=10)
    if response.status_code != 200:
        logger.error(response.text)
        return {}

    json_response = response.json()
    # NOTE(review): serper.dev documents answerBox/knowledgeGraph as objects;
    # the list default for answerBox mirrors the original code — confirm
    # downstream consumers handle both shapes.
    return {
        "knowledgeGraph": json_response.get("knowledgeGraph", {}),
        "organic": json_response.get("organic", []),
        "answerBox": json_response.get("answerBox", []),
        "peopleAlsoAsk": json_response.get("peopleAlsoAsk", []),
    }

View file

@ -4,7 +4,7 @@ import math
import time
import logging
import json
from typing import List, Optional, Union, Any
from typing import List, Optional, Union, Any, Dict
# External Packages
from fastapi import APIRouter, Depends, HTTPException, Header, Request
@ -41,6 +41,7 @@ from khoj.routers.helpers import (
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from khoj.processor.tools.online_search import search_with_google
from fastapi.requests import Request
from database import adapters
@ -602,6 +603,7 @@ async def chat(
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
@ -618,11 +620,22 @@ async def chat(
no_entries_found_format = no_entries_found.format()
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
elif conversation_command == ConversationCommand.Online:
try:
online_results = search_with_google(defiltered_query)
except ValueError as e:
return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
media_type="text/event-stream",
status_code=200,
)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
compiled_references,
online_results,
inferred_queries,
conversation_command,
user,
@ -677,7 +690,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = []
inferred_queries: List[str] = []
if conversation_type == ConversationCommand.General:
if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online:
return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):

View file

@ -6,7 +6,7 @@ from datetime import datetime
from functools import partial
import logging
from time import time
from typing import Iterator, List, Optional, Union, Tuple, Dict
from typing import Iterator, List, Optional, Union, Tuple, Dict, Any
# External Packages
from fastapi import HTTPException, Request
@ -96,6 +96,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Help
elif query.startswith("/general"):
return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@ -116,6 +118,7 @@ def generate_chat_response(
q: str,
meta_log: dict,
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
@ -125,6 +128,7 @@ def generate_chat_response(
chat_response: str,
user_message_time: str,
compiled_references: List[str],
online_results: Dict[str, Any],
inferred_queries: List[str],
meta_log,
):
@ -132,7 +136,11 @@ def generate_chat_response(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
@ -150,6 +158,7 @@ def generate_chat_response(
q,
user_message_time=user_message_time,
compiled_references=compiled_references,
online_results=online_results,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
@ -166,6 +175,7 @@ def generate_chat_response(
loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,
online_results=online_results,
user_query=q,
loaded_model=loaded_model,
conversation_log=meta_log,
@ -181,6 +191,7 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
online_results,
q,
meta_log,
model=chat_model,

View file

@ -279,6 +279,7 @@ command_descriptions = {
ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.",
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}