From 0fcf234f076f237cf59ce03c3893759c79b5bd11 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 17 Nov 2023 16:19:11 -0800 Subject: [PATCH] Add support for using serper.dev for online queries - Use the knowledgeGraph, answerBox, peopleAlsoAsk and organic responses of serper.dev to provide online context for queries made with the /online command - Add it as an additional tool for doing Google searches - Render the results appropriately in the chat web window - Pass appropriate reference data down to the LLM --- src/khoj/interface/web/chat.html | 145 ++++++++++++++++-- .../conversation/gpt4all/chat_model.py | 10 +- src/khoj/processor/conversation/openai/gpt.py | 9 ++ .../processor/conversation/openai/utils.py | 11 +- src/khoj/processor/conversation/prompts.py | 20 ++- src/khoj/processor/conversation/utils.py | 5 +- src/khoj/processor/tools/__init__.py | 0 src/khoj/processor/tools/online_search.py | 38 +++++ src/khoj/routers/api.py | 17 +- src/khoj/routers/helpers.py | 15 +- src/khoj/utils/helpers.py | 1 + 11 files changed, 250 insertions(+), 21 deletions(-) create mode 100644 src/khoj/processor/tools/__init__.py create mode 100644 src/khoj/processor/tools/online_search.py diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 18a03403..4f2b89c8 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -70,6 +70,52 @@ To get started, just start typing below. You can also type / to see a list of co return referenceButton; } + function generateOnlineReference(reference, index) { + + // Generate HTML for Chat Reference + let title = reference.title; + let link = reference.link; + let snippet = reference.snippet; + let question = reference.question; + if (question) { + question = `Question: ${question}
<br><br>
`; + } else { + question = ""; + } + + let linkElement = document.createElement('a'); + linkElement.setAttribute('href', link); + linkElement.setAttribute('target', '_blank'); + linkElement.setAttribute('rel', 'noopener noreferrer'); + linkElement.classList.add("inline-chat-link"); + linkElement.classList.add("reference-link"); + linkElement.setAttribute('title', title); + linkElement.innerHTML = title; + + let referenceButton = document.createElement('button'); + referenceButton.innerHTML = linkElement.outerHTML; + referenceButton.id = `ref-${index}`; + referenceButton.classList.add("reference-button"); + referenceButton.classList.add("collapsed"); + referenceButton.tabIndex = 0; + + // Add event listener to toggle full reference on click + referenceButton.addEventListener('click', function() { + console.log(`Toggling ref-${index}`) + if (this.classList.contains("collapsed")) { + this.classList.remove("collapsed"); + this.classList.add("expanded"); + this.innerHTML = linkElement.outerHTML + `
<br><br>
${question + snippet}`; + } else { + this.classList.add("collapsed"); + this.classList.remove("expanded"); + this.innerHTML = linkElement.outerHTML; + } + }); + + return referenceButton; + } + function renderMessage(message, by, dt=null, annotations=null) { let message_time = formatDate(dt ?? new Date()); let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You"; @@ -99,8 +145,45 @@ To get started, just start typing below. You can also type / to see a list of co chatBody.scrollTop = chatBody.scrollHeight; } - function renderMessageWithReference(message, by, context=null, dt=null) { - if (context == null || context.length == 0) { + function processOnlineReferences(referenceSection, onlineContext) { + let numOnlineReferences = 0; + if (onlineContext.organic && onlineContext.organic.length > 0) { + numOnlineReferences += onlineContext.organic.length; + for (let index in onlineContext.organic) { + let reference = onlineContext.organic[index]; + let polishedReference = generateOnlineReference(reference, index); + referenceSection.appendChild(polishedReference); + } + } + + if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) { + numOnlineReferences += onlineContext.knowledgeGraph.length; + for (let index in onlineContext.knowledgeGraph) { + let reference = onlineContext.knowledgeGraph[index]; + let polishedReference = generateOnlineReference(reference, index); + referenceSection.appendChild(polishedReference); + } + } + + if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) { + numOnlineReferences += onlineContext.peopleAlsoAsk.length; + for (let index in onlineContext.peopleAlsoAsk) { + let reference = onlineContext.peopleAlsoAsk[index]; + let polishedReference = generateOnlineReference(reference, index); + referenceSection.appendChild(polishedReference); + } + } + + return numOnlineReferences; + } + + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { + if (context == null && onlineContext == null) { + renderMessage(message, by, dt); + return; + } + + if ((context && context.length == 0) && (onlineContext == null || (onlineContext && onlineContext.organic.length == 0 && onlineContext.knowledgeGraph.length == 0 && onlineContext.peopleAlsoAsk.length == 0))) { renderMessage(message, by, dt); return; } @@ -109,8 +192,11 @@ To get started, just start typing below. You can also type / to see a list of co let referenceExpandButton = document.createElement('button'); referenceExpandButton.classList.add("reference-expand-button"); - let expandButtonText = context.length == 1 ? "1 reference" : `${context.length} references`; - referenceExpandButton.innerHTML = expandButtonText; + let numReferences = 0; + + if (context) { + numReferences += context.length; + } references.appendChild(referenceExpandButton); @@ -136,6 +222,14 @@ To get started, just start typing below. You can also type / to see a list of co referenceSection.appendChild(polishedReference); } } + + if (onlineContext) { + numReferences += processOnlineReferences(referenceSection, onlineContext); + } + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); renderMessage(message, by, dt, references); @@ -177,6 +271,13 @@ To get started, just start typing below. 
You can also type / to see a list of co // Add the class "chat-response" to each element codeElement.classList.add("chat-response"); }); + + let anchorElements = element.querySelectorAll('a'); + anchorElements.forEach((anchorElement) => { + // Add the class "inline-chat-link" to each element + anchorElement.classList.add("inline-chat-link"); + }); + return element } @@ -258,15 +359,28 @@ To get started, just start typing below. You can also type / to see a list of co let referenceExpandButton = document.createElement('button'); referenceExpandButton.classList.add("reference-expand-button"); - let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceExpandButton); let referenceSection = document.createElement('div'); referenceSection.classList.add("reference-section"); referenceSection.classList.add("collapsed"); + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + referenceExpandButton.addEventListener('click', function() { if (referenceSection.classList.contains("collapsed")) { referenceSection.classList.remove("collapsed"); @@ -277,10 +391,8 @@ To get started, just start typing below. You can also type / to see a list of co } }); - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; references.appendChild(referenceSection); readStream(); } else { @@ -371,7 +483,7 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created)); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); }); }) .catch(err => { @@ -665,6 +777,11 @@ To get started, just start typing below. You can also type / to see a list of co border-bottom: 1px dotted var(--main-text-color); } + a.reference-link { + color: var(--main-text-color); + border-bottom: 1px dotted var(--main-text-color); + } + button.copy-button { display: block; border-radius: 4px; @@ -749,6 +866,10 @@ To get started, just start typing below. 
You can also type / to see a list of co padding: 0px; } + p { + margin: 0; + } + div.programmatic-output { background-color: #f5f5f5; border: 1px solid #ddd; diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index d3eaa01a..6ab02318 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -121,6 +121,7 @@ def filter_questions(questions: List[str]): def converse_offline( references, + online_results, user_query, conversation_log={}, model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", @@ -147,6 +148,13 @@ def converse_offline( # Get Conversation Primer appropriate to Conversation Type if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message): return iter([prompts.no_notes_found.format()]) + elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results): + completion_func(chat_response=prompts.no_online_results_found.format()) + return iter([prompts.no_online_results_found.format()]) + elif conversation_command == ConversationCommand.Online: + conversation_primer = prompts.online_search_conversation.format( + query=user_query, online_results=str(online_results) + ) elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message): conversation_primer = user_query else: @@ -164,7 +172,7 @@ def converse_offline( tokenizer_name=tokenizer_name, ) - g = ThreadedGenerator(references, completion_func=completion_func) + g = ThreadedGenerator(references, online_results, completion_func=completion_func) t = Thread(target=llm_thread, args=(g, messages, gpt4all_model)) t.start() return g diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index b86ebc6b..128488ac 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -88,6 +88,7 @@ def extract_questions( def converse( references, + online_results, user_query, conversation_log={}, model: str = "gpt-3.5-turbo", @@ -109,6 +110,13 @@ def converse( if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references): completion_func(chat_response=prompts.no_notes_found.format()) return iter([prompts.no_notes_found.format()]) + elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results): + completion_func(chat_response=prompts.no_online_results_found.format()) + return iter([prompts.no_online_results_found.format()]) + elif conversation_command == ConversationCommand.Online: + conversation_primer = prompts.online_search_conversation.format( + query=user_query, online_results=str(online_results) + ) elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references): conversation_primer = prompts.general_conversation.format(query=user_query) else: @@ -130,6 +138,7 @@ def converse( return chat_completion_with_backoff( messages=messages, compiled_references=references, + online_results=online_results, model_name=model, temperature=temperature, openai_api_key=api_key, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index dce72e1f..69fab7e5 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -69,9 +69,16 @@ def completion_with_backoff(**kwargs): reraise=True, ) def 
chat_completion_with_backoff( - messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None + messages, + compiled_references, + online_results, + model_name, + temperature, + openai_api_key=None, + completion_func=None, + model_kwargs=None, ): - g = ThreadedGenerator(compiled_references, completion_func=completion_func) + g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs)) t.start() return g diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index f6b84804..7fe1ce48 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -10,7 +10,7 @@ You are Khoj, a smart, inquisitive and helpful personal assistant. Use your general knowledge and the past conversation with the user as context to inform your responses. You were created by Khoj Inc. with the following capabilities: -- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. +- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. They can share files with you using the Khoj desktop application. - You cannot set reminders. - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. @@ -35,6 +35,12 @@ no_notes_found = PromptTemplate.from_template( """.strip() ) +no_online_results_found = PromptTemplate.from_template( + """ + I'm sorry, I couldn't find any relevant information from the internet to respond to your message. + """.strip() +) + no_entries_found = PromptTemplate.from_template( """ It looks like you haven't added any notes yet. No worries, you can fix that by downloading the Khoj app from here. @@ -103,6 +109,18 @@ Question: {query} """.strip() ) +## Online Search Conversation +## -- +online_search_conversation = PromptTemplate.from_template( + """ +Use this up-to-date information from the internet to inform your response. +Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations. 
+ +Information from the internet: {online_results} + +Query: {query}""".strip() +) + ## Summarize Notes ## -- diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b0d401fa..b44195fd 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -29,9 +29,10 @@ model_to_tokenizer = { class ThreadedGenerator: - def __init__(self, compiled_references, completion_func=None): + def __init__(self, compiled_references, online_results, completion_func=None): self.queue = queue.Queue() self.compiled_references = compiled_references + self.online_results = online_results self.completion_func = completion_func self.response = "" self.start_time = perf_counter() @@ -62,6 +63,8 @@ class ThreadedGenerator: def close(self): if self.compiled_references and len(self.compiled_references) > 0: self.queue.put(f"### compiled references:{json.dumps(self.compiled_references)}") + elif self.online_results and len(self.online_results) > 0: + self.queue.put(f"### compiled references:{json.dumps(self.online_results)}") self.queue.put(StopIteration) diff --git a/src/khoj/processor/tools/__init__.py b/src/khoj/processor/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py new file mode 100644 index 00000000..32a2650a --- /dev/null +++ b/src/khoj/processor/tools/online_search.py @@ -0,0 +1,38 @@ +import requests +import json +import os +import logging + +logger = logging.getLogger(__name__) + +SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY") + +url = "https://google.serper.dev/search" + + +def search_with_google(query: str): + if SERPER_DEV_API_KEY is None: + raise ValueError("SERPER_DEV_API_KEY is not set") + + payload = json.dumps( + { + "q": query, + } + ) + + headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"} + + response = requests.request("POST", url, headers=headers, data=payload) + + if response.status_code != 200: + logger.error(response.text) + return {} + + response = response.json() + + response_dict = {} + response_dict["knowledgeGraph"] = response.get("knowledgeGraph", {}) + response_dict["organic"] = response.get("organic", []) + response_dict["answerBox"] = response.get("answerBox", []) + response_dict["peopleAlsoAsk"] = response.get("peopleAlsoAsk", []) + return response_dict diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 85fba38c..c8962938 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -4,7 +4,7 @@ import math import time import logging import json -from typing import List, Optional, Union, Any +from typing import List, Optional, Union, Any, Dict # External Packages from fastapi import APIRouter, Depends, HTTPException, Header, Request @@ -41,6 +41,7 @@ from khoj.routers.helpers import ( from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline +from khoj.processor.tools.online_search import search_with_google from fastapi.requests import Request from database import adapters @@ -602,6 +603,7 @@ async def chat( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( request, meta_log, q, (n or 5), (d or math.inf), conversation_command ) + online_results: Dict = dict() if conversation_command == 
ConversationCommand.Default and is_none_or_empty(compiled_references): conversation_command = ConversationCommand.General @@ -618,11 +620,22 @@ async def chat( no_entries_found_format = no_entries_found.format() return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) + elif conversation_command == ConversationCommand.Online: + try: + online_results = search_with_google(defiltered_query) + except ValueError as e: + return StreamingResponse( + iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]), + media_type="text/event-stream", + status_code=200, + ) + # Get the (streamed) chat response from the LLM of choice. llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, meta_log, compiled_references, + online_results, inferred_queries, conversation_command, user, @@ -677,7 +690,7 @@ async def extract_references_and_questions( compiled_references: List[Any] = [] inferred_queries: List[str] = [] - if conversation_type == ConversationCommand.General: + if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online: return compiled_references, inferred_queries, q if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index b52098e7..207c87a1 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -6,7 +6,7 @@ from datetime import datetime from functools import partial import logging from time import time -from typing import Iterator, List, Optional, Union, Tuple, Dict +from typing import Iterator, List, Optional, Union, Tuple, Dict, Any # External Packages from fastapi import HTTPException, Request @@ -96,6 +96,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.Help elif query.startswith("/general"): return ConversationCommand.General + elif query.startswith("/online"): + return ConversationCommand.Online # If no relevant notes found for the given query elif not any_references: return ConversationCommand.General @@ -116,6 +118,7 @@ def generate_chat_response( q: str, meta_log: dict, compiled_references: List[str] = [], + online_results: Dict[str, Any] = {}, inferred_queries: List[str] = [], conversation_command: ConversationCommand = ConversationCommand.Default, user: KhojUser = None, @@ -125,6 +128,7 @@ def generate_chat_response( chat_response: str, user_message_time: str, compiled_references: List[str], + online_results: Dict[str, Any], inferred_queries: List[str], meta_log, ): @@ -132,7 +136,11 @@ def generate_chat_response( user_message=q, chat_response=chat_response, user_message_metadata={"created": user_message_time}, - khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, + khoj_message_metadata={ + "context": compiled_references, + "intent": {"inferred-queries": inferred_queries}, + "onlineContext": online_results, + }, conversation_log=meta_log.get("chat", []), ) ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) @@ -150,6 +158,7 @@ def generate_chat_response( q, user_message_time=user_message_time, compiled_references=compiled_references, + online_results=online_results, inferred_queries=inferred_queries, meta_log=meta_log, ) @@ -166,6 +175,7 @@ def generate_chat_response( loaded_model = state.gpt4all_processor_config.loaded_model chat_response = converse_offline( references=compiled_references, + 
online_results=online_results, user_query=q, loaded_model=loaded_model, conversation_log=meta_log, @@ -181,6 +191,7 @@ def generate_chat_response( chat_model = conversation_config.chat_model chat_response = converse( compiled_references, + online_results, q, meta_log, model=chat_model, diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index a41de361..799184ae 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -279,6 +279,7 @@ command_descriptions = { ConversationCommand.General: "Only talk about information that relies on Khoj's general knowledge, not your personal knowledge base.", ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", + ConversationCommand.Online: "Look up information on the internet.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", }
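
For quick manual testing, the sketch below exercises the new serper.dev helper on its own, outside the /online chat flow. This is a usage sketch, not part of the patch: it assumes SERPER_DEV_API_KEY is exported and the khoj package is importable, and the dictionary keys mirror the fields this patch extracts from the serper.dev response (knowledgeGraph, organic, answerBox, peopleAlsoAsk).

# Minimal sketch (not part of the patch): call the new serper.dev helper directly.
# Assumes SERPER_DEV_API_KEY is exported; search_with_google raises ValueError otherwise.
from khoj.processor.tools.online_search import search_with_google

results = search_with_google("what is the tallest mountain in the world")

# Organic results carry title/link/snippet, which the chat UI renders as reference buttons.
for entry in results.get("organic", []):
    print(entry.get("title"), "-", entry.get("link"))
    print("   ", entry.get("snippet"))

# peopleAlsoAsk entries additionally carry a question, shown above the snippet when expanded.
for entry in results.get("peopleAlsoAsk", []):
    print("Q:", entry.get("question"))
    print("A:", entry.get("snippet"))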