diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 4141d3bb..7fae4ac7 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -4,6 +4,7 @@ from datetime import date, datetime
 import secrets
 from typing import Type, List
 from datetime import date, timezone
+import random
 
 from django.db import models
 from django.contrib.sessions.backends.db import SessionStore
@@ -339,6 +340,26 @@ class ConversationAdapters:
     async def get_openai_chat_config():
         return await OpenAIProcessorConversationConfig.objects.filter().afirst()
 
+    @staticmethod
+    def get_valid_conversation_config(user: KhojUser):
+        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
+        conversation_config = ConversationAdapters.get_conversation_config(user)
+        if conversation_config is None:
+            conversation_config = ConversationAdapters.get_default_conversation_config()
+
+        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+            if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+
+            return conversation_config
+
+        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
+        if openai_chat_config and conversation_config.model_type == "openai":
+            return conversation_config
+
+        else:
+            raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
+
 
 class EntryAdapters:
     word_filer = WordFilter()
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index ad4a2c49..21acc4a4 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -137,30 +137,33 @@
         function processOnlineReferences(referenceSection, onlineContext) {
             let numOnlineReferences = 0;
 
-            if (onlineContext.organic && onlineContext.organic.length > 0) {
-                numOnlineReferences += onlineContext.organic.length;
-                for (let index in onlineContext.organic) {
-                    let reference = onlineContext.organic[index];
-                    let polishedReference = generateOnlineReference(reference, index);
-                    referenceSection.appendChild(polishedReference);
+            for (let subquery in onlineContext) {
+                let onlineReference = onlineContext[subquery];
+                if (onlineReference.organic && onlineReference.organic.length > 0) {
+                    numOnlineReferences += onlineReference.organic.length;
+                    for (let index in onlineReference.organic) {
+                        let reference = onlineReference.organic[index];
+                        let polishedReference = generateOnlineReference(reference, index);
+                        referenceSection.appendChild(polishedReference);
+                    }
                 }
-            }
 
-            if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-                numOnlineReferences += onlineContext.knowledgeGraph.length;
-                for (let index in onlineContext.knowledgeGraph) {
-                    let reference = onlineContext.knowledgeGraph[index];
-                    let polishedReference = generateOnlineReference(reference, index);
-                    referenceSection.appendChild(polishedReference);
+                if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+                    numOnlineReferences += onlineReference.knowledgeGraph.length;
+                    for (let index in onlineReference.knowledgeGraph) {
+                        let reference = onlineReference.knowledgeGraph[index];
+                        let polishedReference = generateOnlineReference(reference, index);
+                        referenceSection.appendChild(polishedReference);
+                    }
                 }
-            }
 
-            if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-                numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-                for (let index in onlineContext.peopleAlsoAsk) {
-                    let reference = onlineContext.peopleAlsoAsk[index];
-                    let polishedReference = generateOnlineReference(reference, index);
-                    referenceSection.appendChild(polishedReference);
+                if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+                    numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+                    for (let index in onlineReference.peopleAlsoAsk) {
+                        let reference = onlineReference.peopleAlsoAsk[index];
+                        let polishedReference = generateOnlineReference(reference, index);
+                        referenceSection.appendChild(polishedReference);
+                    }
                 }
             }
 
@@ -316,15 +319,28 @@
 
             let referenceExpandButton = document.createElement('button');
             referenceExpandButton.classList.add("reference-expand-button");
-            let expandButtonText = rawReferenceAsJson.length == 1 ? "1 reference" : `${rawReferenceAsJson.length} references`;
-            referenceExpandButton.innerHTML = expandButtonText;
-            references.appendChild(referenceExpandButton);
 
             let referenceSection = document.createElement('div');
             referenceSection.classList.add("reference-section");
             referenceSection.classList.add("collapsed");
 
+            let numReferences = 0;
+
+            // If rawReferenceAsJson is a list, then count the length
+            if (Array.isArray(rawReferenceAsJson)) {
+                numReferences = rawReferenceAsJson.length;
+
+                rawReferenceAsJson.forEach((reference, index) => {
+                    let polishedReference = generateReference(reference, index);
+                    referenceSection.appendChild(polishedReference);
+                });
+            } else {
+                numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+            }
+
+            references.appendChild(referenceExpandButton);
+
             referenceExpandButton.addEventListener('click', function() {
                 if (referenceSection.classList.contains("collapsed")) {
                     referenceSection.classList.remove("collapsed");
@@ -335,10 +351,8 @@
                 }
             });
 
-            rawReferenceAsJson.forEach((reference, index) => {
-                let polishedReference = generateReference(reference, index);
-                referenceSection.appendChild(polishedReference);
-            });
+            let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+            referenceExpandButton.innerHTML = expandButtonText;
             references.appendChild(referenceSection);
             readStream();
         } else {
@@ -419,7 +433,7 @@
         const khojToken = await window.tokenAPI.getToken();
         const headers = { 'Authorization': `Bearer ${khojToken}` };
 
-        fetch(`${hostURL}/api/chat/history?client=web`, { headers })
+        fetch(`${hostURL}/api/chat/history?client=desktop`, { headers })
            .then(response => response.json())
            .then(data => {
                 if (data.detail) {
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index b27312aa..7afa7bee 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -147,35 +147,38 @@ To get started, just start typing below. You can also type / to see a list of co
         function processOnlineReferences(referenceSection, onlineContext) {
             let numOnlineReferences = 0;
 
-            if (onlineContext.organic && onlineContext.organic.length > 0) {
-                numOnlineReferences += onlineContext.organic.length;
-                for (let index in onlineContext.organic) {
-                    let reference = onlineContext.organic[index];
+            for (let subquery in onlineContext) {
+                let onlineReference = onlineContext[subquery];
+                if (onlineReference.organic && onlineReference.organic.length > 0) {
+                    numOnlineReferences += onlineReference.organic.length;
+                    for (let index in onlineReference.organic) {
+                        let reference = onlineReference.organic[index];
                     let polishedReference = generateOnlineReference(reference, index);
                     referenceSection.appendChild(polishedReference);
+                    }
                 }
-            }
 
-            if (onlineContext.knowledgeGraph && onlineContext.knowledgeGraph.length > 0) {
-                numOnlineReferences += onlineContext.knowledgeGraph.length;
-                for (let index in onlineContext.knowledgeGraph) {
-                    let reference = onlineContext.knowledgeGraph[index];
+                if (onlineReference.knowledgeGraph && onlineReference.knowledgeGraph.length > 0) {
+                    numOnlineReferences += onlineReference.knowledgeGraph.length;
+                    for (let index in onlineReference.knowledgeGraph) {
+                        let reference = onlineReference.knowledgeGraph[index];
                     let polishedReference = generateOnlineReference(reference, index);
                     referenceSection.appendChild(polishedReference);
+                    }
                 }
-            }
 
-            if (onlineContext.peopleAlsoAsk && onlineContext.peopleAlsoAsk.length > 0) {
-                numOnlineReferences += onlineContext.peopleAlsoAsk.length;
-                for (let index in onlineContext.peopleAlsoAsk) {
-                    let reference = onlineContext.peopleAlsoAsk[index];
-                let polishedReference = generateOnlineReference(reference, index);
-                referenceSection.appendChild(polishedReference);
+                if (onlineReference.peopleAlsoAsk && onlineReference.peopleAlsoAsk.length > 0) {
+                    numOnlineReferences += onlineReference.peopleAlsoAsk.length;
+                    for (let index in onlineReference.peopleAlsoAsk) {
+                        let reference = onlineReference.peopleAlsoAsk[index];
+                        let polishedReference = generateOnlineReference(reference, index);
+                        referenceSection.appendChild(polishedReference);
+                    }
                 }
-        }
+            }
 
             return numOnlineReferences;
-    }
+        }
 
         function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
             if (context == null && onlineContext == null) {
@@ -356,7 +359,6 @@ To get started, just start typing below. You can also type / to see a list of co
 
             references = document.createElement('div');
             references.classList.add("references");
-
             let referenceExpandButton = document.createElement('button');
             referenceExpandButton.classList.add("reference-expand-button");
 
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 6ab02318..41b1844b 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -179,13 +179,6 @@ def converse_offline(
 
 
 def llm_thread(g, messages: List[ChatMessage], model: Any):
-    try:
-        from gpt4all import GPT4All
-    except ModuleNotFoundError as e:
-        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
-        raise e
-
-    assert isinstance(model, GPT4All), "model should be of type GPT4All"
     user_message = messages[-1]
     system_message = messages[0]
     conversation_history = messages[1:-1]
@@ -204,7 +197,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     prompted_message = templated_system_message + chat_history + templated_user_message
 
     state.chat_lock.acquire()
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
+    response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
     try:
         for response in response_iterator:
             if any(stop_word in response.strip() for stop_word in stop_words):
@@ -214,3 +207,18 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     finally:
         state.chat_lock.release()
         g.close()
+
+
+def send_message_to_model_offline(
+    message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False
+):
+    try:
+        from gpt4all import GPT4All
+    except ModuleNotFoundError as e:
+        logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
+        raise e
+
+    assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
+    gpt4all_model = loaded_model or GPT4All(model)
+
+    return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 909c304f..fed110f7 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -100,6 +100,27 @@ def extract_questions(
     return questions
 
 
+def send_message_to_model(
+    message,
+    api_key,
+    model,
+):
+    """
+    Send message to model
+    """
+    messages = [ChatMessage(content=message, role="assistant")]
+
+    # Get Response from GPT
+    return completion_with_backoff(
+        messages=messages,
+        model_name=model,
+        temperature=0,
+        max_tokens=100,
+        model_kwargs={"stop": ["A: ", "\n"]},
+        openai_api_key=api_key,
+    )
+
+
 def converse(
     references,
     online_results,
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 7fe1ce48..cba7cb59 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -121,6 +121,32 @@ Information from the internet: {online_results}
 Query: {query}""".strip()
 )
 
+online_search_conversation_subqueries = PromptTemplate.from_template(
+    """
+The user has a question which you can use the internet to respond to. Can you break down the question into subqueries to get the correct answer? Provide search queries as a JSON list of strings
+
+Today's date in UTC: {current_date}
+
+Here are some examples of questions and subqueries:
+Q: What is the weather like in New York?
+A: ["weather in new york"]
+
+Q: What is the weather like in New York and San Francisco?
+A: ["weather in new york", "weather in san francisco"]
+
+Q: What is the latest news about Google stock?
+A: ["google stock news"]
+
+Q: When is the next lunar eclipse?
+A: ["next lunar eclipse"]
+
+Q: How many oranges would fit in NASA's Saturn V rocket?
+A: ["volume of an orange", "volume of saturn v rocket"]
+
+This is the user's query:
+Q: {query}
+A: """.strip()
+)
 
 ## Summarize Notes
 ## --
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 7ac8273e..f014fb4a 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -3,6 +3,8 @@ import json
 import os
 import logging
 
+from khoj.routers.helpers import generate_online_subqueries
+
 logger = logging.getLogger(__name__)
 
 SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
@@ -10,29 +12,41 @@ SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
 
 url = "https://google.serper.dev/search"
 
-def search_with_google(query: str):
+async def search_with_google(query: str):
+    def _search_with_google(subquery: str):
+        payload = json.dumps(
+            {
+                "q": subquery,
+            }
+        )
+
+        headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+
+        if response.status_code != 200:
+            logger.error(response.text)
+            return {}
+
+        json_response = response.json()
+        sub_response_dict = {}
+        sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
+        sub_response_dict["organic"] = json_response.get("organic", [])
+        sub_response_dict["answerBox"] = json_response.get("answerBox", [])
+        sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+
+        return sub_response_dict
+
     if SERPER_DEV_API_KEY is None:
         raise ValueError("SERPER_DEV_API_KEY is not set")
 
-    payload = json.dumps(
-        {
-            "q": query,
-        }
-    )
-
-    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
-
-    response = requests.request("POST", url, headers=headers, data=payload)
-
-    if response.status_code != 200:
-        logger.error(response.text)
-        return {}
-
-    json_response = response.json()
+    # Breakdown the query into subqueries to get the correct answer
+    subqueries = await generate_online_subqueries(query)
 
     response_dict = {}
-    response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
-    response_dict["organic"] = json_response.get("organic", [])
-    response_dict["answerBox"] = json_response.get("answerBox", [])
-    response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
+
+    for subquery in subqueries:
+        logger.info(f"Searching with Google for '{subquery}'")
+        response_dict[subquery] = _search_with_google(subquery)
+
     return response_dict
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index d9b80756..10db3a83 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -597,7 +597,7 @@ async def chat(
 
     elif conversation_command == ConversationCommand.Online:
         try:
-            online_results = search_with_google(defiltered_query)
+            online_results = await search_with_google(defiltered_query)
         except ValueError as e:
             return StreamingResponse(
                 iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index b609e977..cdb91b6f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -6,8 +6,13 @@ from datetime import datetime
 from functools import partial
 import logging
 from time import time
+import json
 from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any
+from datetime import datetime
+
+from khoj.processor.conversation import prompts
+
 
 # External Packages
 from fastapi import HTTPException, Header, Request, Depends
@@ -15,10 +20,10 @@ from fastapi import HTTPException, Header, Request, Depends
 from khoj.utils import state
 from khoj.utils.config import GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, log_telemetry
-from khoj.processor.conversation.openai.gpt import converse
-from khoj.processor.conversation.gpt4all.chat_model import converse_offline
+from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
+from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
 from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
-from database.models import KhojUser, Subscription
+from database.models import KhojUser, Subscription, ChatModelOptions
 from database.adapters import ConversationAdapters
 
 
@@ -114,6 +119,65 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)
 
 
+async def generate_online_subqueries(q: str) -> List[str]:
+    """
+    Generate subqueries from the given query
+    """
+    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+    online_queries_prompt = prompts.online_search_conversation_subqueries.format(
+        current_date=utc_date,
+        query=q,
+    )
+
+    response = await send_message_to_model_wrapper(online_queries_prompt)
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    try:
+        response = response.strip()
+        response = json.loads(response)
+        response = [q.strip() for q in response if q.strip()]
+        if not isinstance(response, list) or not response or len(response) == 0:
+            logger.error(f"Invalid response for constructing subqueries: {response}")
+            return [q]
+        return response
+    except Exception as e:
+        logger.error(f"Invalid response for constructing subqueries: {response}")
+        return [q]
+
+
+async def send_message_to_model_wrapper(
+    message: str,
+):
+    conversation_config = await ConversationAdapters.aget_default_conversation_config()
+
+    if conversation_config is None:
+        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
+
+    if conversation_config.model_type == "offline":
+        if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+            state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
+
+        loaded_model = state.gpt4all_processor_config.loaded_model
+        return send_message_to_model_offline(
+            message=message,
+            loaded_model=loaded_model,
+            model=conversation_config.chat_model,
+            streaming=False,
+        )
+
+    elif conversation_config.model_type == "openai":
+        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        api_key = openai_chat_config.api_key
+        chat_model = conversation_config.chat_model
+        return send_message_to_model(
+            message=message,
+            api_key=api_key,
+            model=chat_model,
+        )
+    else:
+        raise HTTPException(status_code=500, detail="Invalid conversation config")
+
+
 def generate_chat_response(
     q: str,
     meta_log: dict,
@@ -163,12 +227,8 @@ def generate_chat_response(
             meta_log=meta_log,
         )
 
-        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
-        conversation_config = ConversationAdapters.get_conversation_config(user)
-        if conversation_config is None:
-            conversation_config = ConversationAdapters.get_default_conversation_config()
-        openai_chat_config = ConversationAdapters.get_openai_conversation_config()
-        if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
+        conversation_config = ConversationAdapters.get_valid_conversation_config(user)
+        if conversation_config.model_type == "offline":
             if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
                 state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
 
@@ -186,7 +246,8 @@ def generate_chat_response(
                 tokenizer_name=conversation_config.tokenizer,
             )
 
-        elif openai_chat_config and conversation_config.model_type == "openai":
+        elif conversation_config.model_type == "openai":
+            openai_chat_config = ConversationAdapters.get_openai_conversation_config()
             api_key = openai_chat_config.api_key
             chat_model = conversation_config.chat_model
             chat_response = converse(