diff --git a/pyproject.toml b/pyproject.toml index 4b0c4140..9924dc0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "uvicorn == 0.17.6", "aiohttp ~= 3.9.0", "langchain <= 0.2.0", + "langchain-openai >= 0.0.5", "requests >= 2.26.0", "bs4 >= 0.0.1", "anyio == 3.7.1", diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 45b16059..462ae05d 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -16,9 +16,9 @@ Hi, I am Khoj, your open, personal AI 👋🏽. I can help: - 🧠 Answer general knowledge questions - 💡 Be a sounding board for your ideas - 📜 Chat with your notes & documents -- 🌄 Generate images based on your messages -- 🔎 Search the web for answers to your questions -- 🎙️ Listen to your audio messages +- 🌄 Generate images based on your messages (start your prompt with "/image") +- 🔎 Search the web for answers to your questions (start your prompt with "/online") +- 🎙️ Listen to your audio messages (use the mic by the input box to speak your message) Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj.dev/#/obsidian?id=setup) or [Emacs](https://docs.khoj.dev/#/emacs?id=setup) app to search, chat with your 🖥️ computer docs. diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 3e0f5380..8361650c 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -155,8 +155,13 @@ def converse_offline( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) elif conversation_command == ConversationCommand.Online: + simplified_online_results = online_results.copy() + for result in online_results: + if online_results[result].get("extracted_content"): + simplified_online_results[result] = online_results[result]["extracted_content"] + conversation_primer = prompts.online_search_conversation.format( - query=user_query, online_results=str(online_results) + query=user_query, online_results=str(simplified_online_results) ) elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message): conversation_primer = user_query @@ -213,7 +218,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any): def send_message_to_model_offline( - message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False + message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message="" ): try: from gpt4all import GPT4All @@ -224,4 +229,6 @@ def send_message_to_model_offline( assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None" gpt4all_model = loaded_model or GPT4All(model) - return gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming) + return gpt4all_model.generate( + system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming + ) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 3038aa00..a9030926 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -97,22 +97,18 @@ def extract_questions( def send_message_to_model( - message, + messages, api_key, model, ): """ Send message to model """ - messages = [ChatMessage(content=message, 
role="assistant")] # Get Response from GPT return completion_with_backoff( messages=messages, - model_name=model, - temperature=0, - max_tokens=100, - model_kwargs={"stop": ["A: ", "\n"]}, + model=model, openai_api_key=api_key, ) @@ -120,7 +116,7 @@ def send_message_to_model( def converse( references, user_query, - online_results=[], + online_results: Optional[dict] = None, conversation_log={}, model: str = "gpt-3.5-turbo", api_key: Optional[str] = None, @@ -145,8 +141,13 @@ def converse( completion_func(chat_response=prompts.no_online_results_found.format()) return iter([prompts.no_online_results_found.format()]) elif conversation_command == ConversationCommand.Online: + simplified_online_results = online_results.copy() + for result in online_results: + if online_results[result].get("extracted_content"): + simplified_online_results[result] = online_results[result]["extracted_content"] + conversation_primer = prompts.online_search_conversation.format( - query=user_query, online_results=str(online_results) + query=user_query, online_results=str(simplified_online_results) ) elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references): conversation_primer = prompts.general_conversation.format(query=user_query) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index c0a9929c..00ad74ce 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -6,7 +6,7 @@ from typing import Any import openai from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain_community.chat_models import ChatOpenAI +from langchain_openai import ChatOpenAI from tenacity import ( before_sleep_log, retry, @@ -48,7 +48,10 @@ def completion_with_backoff(**kwargs): if not "openai_api_key" in kwargs: kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) - return llm(messages=messages) + aggregated_response = "" + for chunk in llm.stream(messages): + aggregated_response += chunk.content + return aggregated_response @retry( diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index b9f6dc23..67d8725f 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -257,18 +257,39 @@ Q: {text} """ ) +system_prompt_extract_relevant_information = """As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query. The text provided is directly from within the web page. The report you create should be multiple paragraphs, and it should represent the content of the website. Tell the user exactly what the website says in response to their query, while adhering to these guidelines: + +1. Answer the user's query as specifically as possible. Include many supporting details from the website. +2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity. +3. Rely strictly on the provided text, without including external information. +4. Format the report in multiple paragraphs with a clear structure. +5. Be as specific as possible in your answer to the user's query. +6. Reproduce as much of the provided text as possible, while maintaining readability. 
+""".strip() + +extract_relevant_information = PromptTemplate.from_template( + """ +Target Query: {query} + +Web Pages: {corpus} + +Collate the relevant information from the website to answer the target query. +""".strip() +) + online_search_conversation_subqueries = PromptTemplate.from_template( """ -You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing a search query for Google to answer the user's question. +You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question. - You will receive the conversation history as context. - Add as much context from the previous questions and answers as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. - You have access to the the whole internet to retrieve information. What Google searches, if any, will you need to perform to answer the user's question? -Provide search queries as a JSON list of strings +Provide search queries as a list of strings Current Date: {current_date} +Here are some examples: History: User: I like to use Hacker News to get my tech news. Khoj: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups. @@ -297,6 +318,7 @@ Khoj: NASA's Saturn V rocket frequently makes lunar trips and has a large cargo Q: How many oranges would fit in NASA's Saturn V rocket? A: ["volume of an orange", "volume of saturn v rocket"] +Now it's your turn to construct a search query for Google to answer the user's question. History: {chat_history} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a05605e4..1fca1a9e 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -15,8 +15,10 @@ from khoj.utils.helpers import merge_dicts logger = logging.getLogger(__name__) model_to_prompt_size = { - "gpt-3.5-turbo": 4096, - "gpt-4": 8192, + "gpt-3.5-turbo": 3000, + "gpt-4": 7000, + "gpt-4-1106-preview": 7000, + "gpt-4-turbo-preview": 7000, "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "gpt-3.5-turbo-16k": 15000, "mistral-7b-instruct-v0.1.Q4_0.gguf": 1548, @@ -194,6 +196,7 @@ def truncate_messages( assert type(system_message.content) == str current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" + original_question = f"\n{original_question}" original_question_tokens = len(encoder.encode(original_question)) remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip() diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index c9b601fd..fe48bfa7 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -1,17 +1,36 @@ import json import logging import os +from typing import Dict, List, Union import requests -from khoj.routers.helpers import generate_online_subqueries +from khoj.routers.helpers import extract_relevant_info, generate_online_subqueries +from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY") 
+OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
 
-url = "https://google.serper.dev/search"
+SERPER_DEV_URL = "https://google.serper.dev/search"
+
+OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
+
+OLOSTEP_QUERY_PARAMS = {
+    "timeout": 35,  # seconds
+    "waitBeforeScraping": 1,  # seconds
+    "saveHtml": False,
+    "saveMarkdown": True,
+    "removeCSSselectors": "default",
+    "htmlTransformer": "none",
+    "removeImages": True,
+    "fastLane": True,
+    # Similar to Stripe's API, the expand parameters avoid the need to make a second API call
+    # to retrieve the dataset (from the dataset API) if you only need the markdown or html.
+    "expandMarkdown": True,
+    "expandHtml": False,
+}
 
 
 async def search_with_google(query: str, conversation_history: dict):
@@ -24,7 +43,7 @@ async def search_with_google(query: str, conversation_history: dict):
 
     headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
 
-    response = requests.request("POST", url, headers=headers, data=payload)
+    response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
 
     if response.status_code != 200:
         logger.error(response.text)
@@ -51,4 +70,41 @@ async def search_with_google(query: str, conversation_history: dict):
         logger.info(f"Searching with Google for '{subquery}'")
         response_dict[subquery] = _search_with_google(subquery)
 
+    extracted_content: Dict[str, List] = {}
+    if is_none_or_empty(OLOSTEP_API_KEY):
+        logger.warning("OLOSTEP_API_KEY is not set. Skipping web scraping.")
+        return response_dict
+
+    for subquery in response_dict:
+        # If a high quality answer is not found, scrape the web page of the first organic result
+        if is_none_or_empty(response_dict[subquery].get("answerBox")):
+            extracted_content[subquery] = []
+            for result in response_dict[subquery].get("organic", [])[:1]:
+                logger.info(f"Searching web page of '{result['link']}'")
+                try:
+                    extracted_content[subquery].append(search_with_olostep(result["link"]).strip())
+                except Exception as e:
+                    logger.error(f"Error while searching web page of '{result['link']}': {e}", exc_info=True)
+                    continue
+            # Pass only this subquery's scraped content, so the extraction prompt uses the right corpus
+            extracted_relevant_content = await extract_relevant_info(subquery, {subquery: extracted_content[subquery]})
+            response_dict[subquery]["extracted_content"] = extracted_relevant_content
+    return response_dict
+
+
+def search_with_olostep(web_url: str) -> Optional[str]:
+    if OLOSTEP_API_KEY is None:
+        raise ValueError("OLOSTEP_API_KEY is not set")
+
+    headers = {"Authorization": f"Bearer {OLOSTEP_API_KEY}"}
+
+    web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy()  # type: ignore
+    web_scraping_params["url"] = web_url
+
+    response = requests.request("GET", OLOSTEP_API_URL, params=web_scraping_params, headers=headers)
+
+    if response.status_code != 200:
+        logger.error(response.text)
+        return None
+
+    return response.json()["markdown_content"]
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index ee51a8a6..d7f9214b 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -14,6 +14,7 @@ from starlette.authentication import has_required_scope
 from khoj.database.adapters import ConversationAdapters, EntryAdapters
 from khoj.database.models import (
+    ChatModelOptions,
     ClientApplication,
     KhojUser,
     Subscription,
@@ -27,6 +28,7 @@ from khoj.processor.conversation.offline.chat_model import (
 from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    generate_chatml_messages_with_context,
     save_to_conversation_log,
 )
 from khoj.utils import state
@@ -158,6 +160,24 @@ async def generate_online_subqueries(q: str, conversation_history: dict) -> List
     return [q]
 
 
+async def extract_relevant_info(q: str, corpus: dict) -> str:
+    """
+    Extract the most relevant information from the given corpus for the given query
+    """
+
+    key = list(corpus.keys())[0]
+    extract_relevant_information = prompts.extract_relevant_information.format(
+        query=q,
+        corpus=corpus[key],
+    )
+
+    response = await send_message_to_model_wrapper(
+        extract_relevant_information, prompts.system_prompt_extract_relevant_information
+    )
+
+    return response.strip()
+
+
 async def generate_better_image_prompt(q: str, conversation_history: str) -> str:
     """
     Generate a better image prompt from the given query
@@ -175,22 +195,28 @@ async def generate_better_image_prompt(q: str, conversation_history: str) -> str
 
 async def send_message_to_model_wrapper(
     message: str,
+    system_message: str = "",
 ):
-    conversation_config = await ConversationAdapters.aget_default_conversation_config()
+    conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
 
     if conversation_config is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
 
+    truncated_messages = generate_chatml_messages_with_context(
+        user_message=message, system_message=system_message, model_name=conversation_config.chat_model
+    )
+
     if conversation_config.model_type == "offline":
         if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
             state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
 
         loaded_model = state.gpt4all_processor_config.loaded_model
 
         return send_message_to_model_offline(
-            message=message,
+            message=truncated_messages[-1].content,
             loaded_model=loaded_model,
             model=conversation_config.chat_model,
             streaming=False,
+            system_message=truncated_messages[0].content,
         )
 
     elif conversation_config.model_type == "openai":
@@ -198,12 +224,12 @@ async def send_message_to_model_wrapper(
         api_key = openai_chat_config.api_key
         chat_model = conversation_config.chat_model
         openai_response = send_message_to_model(
-            message=message,
+            messages=truncated_messages,
             api_key=api_key,
             model=chat_model,
         )
 
-        return openai_response.content
+        return openai_response
 
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py
index 0173ff7b..dd9b2d21 100644
--- a/tests/test_gpt4all_chat_director.py
+++ b/tests/test_gpt4all_chat_director.py
@@ -57,11 +57,11 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
 @pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
 @pytest.mark.chatquality
 @pytest.mark.django_db(transaction=True)
-def test_chat_with_online_content(chat_client):
+def test_chat_with_online_content(client_offline_chat):
     # Act
     q = "/online give me the link to paul graham's essay how to do great work"
     encoded_q = quote(q, safe="")
-    response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
+    response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
     response_message = response.content.decode("utf-8")
 
     response_message = response_message.split("### compiled references")[0]
 
@@ -70,7 +70,31 @@
     # Assert
     expected_responses = ["http://www.paulgraham.com/greatwork.html"]
     assert response.status_code == 200
     assert any([expected_response in response_message for expected_response in expected_responses]), (
-        "Expected assistants name, [K|k]hoj, in response but got: " + response_message
+        "Expected Paul Graham's essay link in response but got: " + response_message
     )
 
 
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.skipif(
+    os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
+    reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
+)
+@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
+def test_chat_with_online_webpage_content(client_offline_chat):
+    # Act
+    q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
+    encoded_q = quote(q, safe="")
+    response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true")
+    response_message = response.content.decode("utf-8")
+
+    response_message = response_message.split("### compiled references")[0]
+
+    # Assert
+    expected_responses = ["185", "1871", "horse"]
+    assert response.status_code == 200
+    assert any([expected_response in response_message for expected_response in expected_responses]), (
+        "Expected Great Chicago Fire details in response but got: " + response_message
+    )
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 79cf1e9d..215c1430 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,3 +1,4 @@
+import os
 import secrets
 
 import numpy as np
@@ -6,6 +7,7 @@ import pytest
 from scipy.stats import linregress
 
 from khoj.processor.embeddings import EmbeddingsModel
+from khoj.processor.tools.online_search import search_with_olostep
 from khoj.utils import helpers
 
@@ -80,3 +82,18 @@ def test_encode_docs_memory_leak():
     # If slope is positive memory utilization is increasing
     # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
     assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
+
+
+@pytest.mark.skipif(os.getenv("OLOSTEP_API_KEY") is None, reason="OLOSTEP_API_KEY is not set")
+def test_olostep_api():
+    # Arrange
+    website = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
+
+    # Act
+    response = search_with_olostep(website)
+
+    # Assert
+    assert (
+        "An alarm sent from the area near the fire also failed to register at the courthouse where the fire watchmen were"
+        in response
+    )
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index a1ece5a7..fc92d4b7 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -73,6 +73,30 @@ def test_chat_with_online_content(chat_client):
     )
 
 
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.skipif(
+    os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
+    reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
+)
+@pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
+def test_chat_with_online_webpage_content(chat_client):
+    # Act
+    q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
+    encoded_q = quote(q, safe="")
+    response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true")
+    response_message = response.content.decode("utf-8")
+
+    response_message = response_message.split("### compiled references")[0]
+
+    # Assert
+    expected_responses = ["185", "1871", "horse"]
+    assert response.status_code == 200
+    assert any([expected_response in response_message for expected_response in expected_responses]), (
+        "Expected Great Chicago Fire details in response but got: " + response_message
+    )
+
+
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
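
The change to completion_with_backoff in openai/utils.py swaps the blocking llm(messages=...) call for streamed chunks that are concatenated into one string. Below is a minimal standalone sketch of that pattern, assuming langchain-openai >= 0.0.5 (pinned in pyproject.toml above) and an OPENAI_API_KEY in the environment; the tenacity retry decorator from the real function is omitted for brevity.

# Standalone sketch of the stream-and-aggregate pattern in completion_with_backoff().
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI


def complete(messages, model: str = "gpt-3.5-turbo") -> str:
    llm = ChatOpenAI(model=model, request_timeout=20, max_retries=1)
    aggregated_response = ""
    for chunk in llm.stream(messages):  # yields AIMessageChunk objects as tokens arrive
        aggregated_response += chunk.content
    return aggregated_response


if __name__ == "__main__":
    print(complete([HumanMessage(content="Say hello in five words.")]))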
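
Both converse() implementations apply the same simplification before building the conversation primer: a subquery whose web page was scraped and distilled has its raw Serper response replaced by the "extracted_content" string, which keeps the prompt small. A toy illustration of that substitution, with fabricated data:

# Toy data illustrating the simplification both converse() variants perform.
online_results = {
    "great chicago fire year": {
        "organic": [{"link": "https://en.wikipedia.org/wiki/Great_Chicago_Fire"}],
        "extracted_content": "The Great Chicago Fire burned from October 8 to October 10, 1871 ...",
    },
    "chicago fire department size 1871": {
        "organic": [{"link": "https://example.com/unscraped"}],  # no extracted_content
    },
}

simplified_online_results = online_results.copy()
for result in online_results:
    if online_results[result].get("extracted_content"):
        simplified_online_results[result] = online_results[result]["extracted_content"]

# The scraped subquery collapses to a string; the other keeps its raw result dict.
assert isinstance(simplified_online_results["great chicago fire year"], str)
assert isinstance(simplified_online_results["chicago fire department size 1871"], dict)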
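
Taken together, online_search.py now runs two HTTP legs per subquery: a Serper search, then an Olostep scrape of the first organic result whenever no answer box comes back. A condensed, synchronous sketch of those legs; the Olostep endpoint and parameters are copied from the diff, while the Serper JSON body ({"q": ...}) is an assumption, since the payload construction sits outside the hunks shown. Requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY.

# Condensed sketch of the search -> scrape flow added in online_search.py.
import os

import requests

SERPER_DEV_URL = "https://google.serper.dev/search"
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"


def scrape_top_result(subquery: str) -> str:
    # Leg 1: Google search via Serper (request body is an assumption, see above)
    search = requests.post(
        SERPER_DEV_URL,
        headers={"X-API-KEY": os.environ["SERPER_DEV_API_KEY"], "Content-Type": "application/json"},
        json={"q": subquery},
        timeout=30,
    )
    search.raise_for_status()
    organic = search.json().get("organic", [])
    if not organic:
        return ""

    # Leg 2: scrape the first organic result to markdown via Olostep
    scrape = requests.get(
        OLOSTEP_API_URL,
        headers={"Authorization": f"Bearer {os.environ['OLOSTEP_API_KEY']}"},
        params={"url": organic[0]["link"], "timeout": 35, "saveMarkdown": True, "expandMarkdown": True},
        timeout=60,
    )
    scrape.raise_for_status()
    return scrape.json()["markdown_content"]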