diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 44ba4ec2..036e798b 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config): else: # Initialize Conversation Logs conversation_processor.meta_log = {} - conversation_processor.chat_session = "" + conversation_processor.chat_session = [] return conversation_processor @@ -225,9 +225,9 @@ def save_chat_session(): chat_session = state.processor_config.conversation.chat_session openai_api_key = state.processor_config.conversation.openai_api_key conversation_log = state.processor_config.conversation.meta_log - model = state.processor_config.conversation.model + chat_model = state.processor_config.conversation.chat_model session = { - "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), + "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key), "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-end": len(conversation_log["chat"]), } @@ -242,7 +242,7 @@ def save_chat_session(): with open(conversation_logfile, "w+", encoding="utf-8") as logfile: json.dump(conversation_log, logfile, indent=2) - state.processor_config.conversation.chat_session = None + state.processor_config.conversation.chat_session = [] logger.info("📩 Saved current chat session to conversation logs") diff --git a/src/khoj/interface/web/processor_conversation_input.html b/src/khoj/interface/web/processor_conversation_input.html index 8f04a3b3..24cbc666 100644 --- a/src/khoj/interface/web/processor_conversation_input.html +++ b/src/khoj/interface/web/processor_conversation_input.html @@ -16,9 +16,17 @@ + + + + + + + + - - +
+ @@ -34,14 +42,6 @@ - - - -
- - - -
diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 8051afdb..226af3fb 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -1,9 +1,11 @@ # Standard Packages -import json import logging from datetime import datetime from typing import Optional +# External Packages +from langchain.schema import ChatMessage + # Internal Packages from khoj.utils.constants import empty_escape_sequences from khoj.processor.conversation import prompts @@ -17,44 +19,16 @@ from khoj.processor.conversation.utils import ( logger = logging.getLogger(__name__) -def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500): +def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200): """ - Answer user query using provided text as reference with OpenAI's GPT + Summarize conversation session using the specified OpenAI chat model """ - # Setup Prompt from arguments - prompt = prompts.answer.format(text=text, user_query=user_query) + messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session # Get Response from GPT - logger.debug(f"Prompt for GPT: {prompt}") + logger.debug(f"Prompt for GPT: {messages}") response = completion_with_backoff( - prompt=prompt, - model_name=model, - temperature=temperature, - max_tokens=max_tokens, - model_kwargs={"stop": ['"""']}, - openai_api_key=api_key, - ) - - # Extract, Clean Message from GPT's Response - return str(response).replace("\n\n", "") - - -def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200): - """ - Summarize user input using OpenAI's GPT - """ - # Setup Prompt based on Summary Type - if summary_type == "chat": - prompt = prompts.summarize_chat.format(text=text) - elif summary_type == "notes": - prompt = prompts.summarize_notes.format(text=text, user_query=user_query) - else: - raise ValueError(f"Invalid summary type: {summary_type}") - - # Get Response from GPT - logger.debug(f"Prompt for GPT: {prompt}") - response = completion_with_backoff( - prompt=prompt, + messages=messages, model_name=model, temperature=temperature, max_tokens=max_tokens, @@ -64,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat ) # Extract, Clean Message from GPT's Response - return str(response).replace("\n\n", "") + return str(response.content).replace("\n\n", "") def extract_questions( - text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100 + text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100 ): """ Infer search queries to retrieve relevant notes to answer user query @@ -97,10 +71,11 @@ def extract_questions( chat_history=chat_history, text=text, ) + messages = [ChatMessage(content=prompt, role="assistant")] # Get Response from GPT response = completion_with_backoff( - prompt=prompt, + messages=messages, model_name=model, temperature=temperature, max_tokens=max_tokens, @@ -111,7 +86,7 @@ def extract_questions( # Extract, Clean Message from GPT's Response try: questions = ( - response.strip(empty_escape_sequences) + response.content.strip(empty_escape_sequences) .replace("['", '["') .replace("']", '"]') .replace("', '", '", "') @@ -126,31 +101,6 @@ def extract_questions( return questions -def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0): - """ - Extract search type from user query using OpenAI's GPT - """ - # Setup Prompt to extract search type - prompt = prompts.search_type + f"{text}\nA:" - if verbose > 1: - print(f"Message -> Prompt: {text} -> {prompt}") - - # Get Response from GPT - logger.debug(f"Prompt for GPT: {prompt}") - response = completion_with_backoff( - prompt=prompt, - model_name=model, - temperature=temperature, - max_tokens=max_tokens, - frequency_penalty=0.2, - model_kwargs={"stop": ["\n"]}, - openai_api_key=api_key, - ) - - # Extract, Clean Message from GPT's Response - return json.loads(response.strip(empty_escape_sequences)) - - def converse( references, user_query, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 04b5d108..c04e9042 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -37,12 +37,7 @@ Question: {query} ## Summarize Chat ## -- summarize_chat = PromptTemplate.from_template( - """ -You are an AI. Summarize the conversation below from your perspective: - -{text} - -Summarize the conversation from the AI's first-person perspective:""" + f"{personality.format()} Summarize the conversation from your first person perspective" ) @@ -102,7 +97,7 @@ A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi. Q: What national parks did I go to last year? -["National park I visited in {last_new_year} dt>="{last_new_year_date}" dt<"{current_new_year_date}""] +["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"] A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}. diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 99084bf0..e77b7899 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -9,7 +9,6 @@ import json # External Packages from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI from langchain.schema import ChatMessage from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.base import BaseCallbackManager @@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): reraise=True, ) def completion_with_backoff(**kwargs): - prompt = kwargs.pop("prompt") - if "api_key" in kwargs: - kwargs["openai_api_key"] = kwargs.get("api_key") - else: + messages = kwargs.pop("messages") + if not "openai_api_key" in kwargs: kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") - llm = OpenAI(**kwargs, request_timeout=20, max_retries=1) - return llm(prompt) + llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) + return llm(messages=messages) @retry( @@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None): streaming=True, verbose=True, callback_manager=BaseCallbackManager([callback_handler]), - model_name=model_name, + model_name=model_name, # type: ignore temperature=temperature, openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), request_timeout=20, max_retries=1, + client=None, ) chat(messages=messages) @@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair): return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] -def message_to_prompt( - user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:" -): - """Create prompt for GPT from messages and conversation history""" - gpt_message = f" {gpt_message}" if gpt_message else "" - - return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}" - - def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]): """Create json logs from messages, metadata for conversation log""" default_khoj_message_metadata = { diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index e87a80d0..7ccdf5cd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -15,7 +15,7 @@ from sentence_transformers import util # Internal Packages from khoj.configure import configure_processor, configure_search from khoj.processor.conversation.gpt import converse, extract_questions -from khoj.processor.conversation.utils import message_to_log, message_to_prompt +from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -448,10 +448,9 @@ async def chat( user_message_time: str, compiled_references: List[str], inferred_queries: List[str], - chat_session: str, meta_log, ): - state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) + state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response]) state.processor_config.conversation.meta_log["chat"] = message_to_log( q, gpt_response, @@ -470,7 +469,6 @@ async def chat( ) # Load Conversation History - chat_session = state.processor_config.conversation.chat_session meta_log = state.processor_config.conversation.meta_log # If user query is empty, return nothing @@ -479,7 +477,6 @@ async def chat( # Initialize Variables api_key = state.processor_config.conversation.openai_api_key - model = state.processor_config.conversation.model chat_model = state.processor_config.conversation.chat_model user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_type = "general" if q.startswith("@general") else "notes" @@ -489,7 +486,7 @@ async def chat( if conversation_type == "notes": # Infer search queries from user message with timer("Extracting search queries took", logger): - inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log) + inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log) # Collate search results as context for GPT with timer("Searching knowledge base took", logger): @@ -525,7 +522,6 @@ async def chat( user_message_time=user_message_time, compiled_references=compiled_references, inferred_queries=inferred_queries, - chat_session=chat_session, meta_log=meta_log, ) diff --git a/src/khoj/routers/api_beta.py b/src/khoj/routers/api_beta.py index 01b7a086..83cca612 100644 --- a/src/khoj/routers/api_beta.py +++ b/src/khoj/routers/api_beta.py @@ -1,64 +1,9 @@ # Standard Packages import logging -from typing import Optional # External Packages from fastapi import APIRouter -# Internal Packages -from khoj.routers.api import search -from khoj.processor.conversation.gpt import ( - answer, - extract_search_type, -) -from khoj.utils.state import SearchType -from khoj.utils.helpers import get_from_dict -from khoj.utils import state - - # Initialize Router api_beta = APIRouter() logger = logging.getLogger(__name__) - - -# Create Routes -@api_beta.get("/search") -def search_beta(q: str, n: Optional[int] = 1): - # Initialize Variables - model = state.processor_config.conversation.model - api_key = state.processor_config.conversation.openai_api_key - - # Extract Search Type using GPT - try: - metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose) - search_type = get_from_dict(metadata, "search-type") - except Exception as e: - return {"status": "error", "result": [str(e)], "type": None} - - # Search - search_results = search(q, n=n, t=SearchType(search_type)) - - # Return response - return {"status": "ok", "result": search_results, "type": search_type} - - -@api_beta.get("/answer") -def answer_beta(q: str): - # Initialize Variables - model = state.processor_config.conversation.model - api_key = state.processor_config.conversation.openai_api_key - - # Collate context for GPT - result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False) - collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list]) - logger.debug(f"Reference Context:\n{collated_result}") - - # Make GPT respond to user query using provided context - try: - gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key) - status = "ok" - except Exception as e: - gpt_response = str(e) - status = "error" - - return {"status": status, "response": gpt_response} diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index be07eefd..b612f16b 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -23,7 +23,7 @@ class DateFilter(BaseFilter): # - dt>="yesterday" dt<"tomorrow" # - dt>="last week" # - dt:"2 years ago" - date_regex = r"dt([:><=]{1,2})\"(.*?)\"" + date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']" def __init__(self, entry_key="raw"): self.entry_key = entry_key diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 3adc6e9d..155cdcc6 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -72,7 +72,7 @@ class ConversationProcessorConfigModel: self.model = processor_config.model self.chat_model = processor_config.chat_model self.conversation_logfile = Path(processor_config.conversation_logfile) - self.chat_session = "" + self.chat_session: List[str] = [] self.meta_log: dict = {} diff --git a/tests/test_chat_actors.py b/tests/test_chat_actors.py index 9f8e821d..e645fc13 100644 --- a/tests/test_chat_actors.py +++ b/tests/test_chat_actors.py @@ -33,9 +33,9 @@ def test_extract_question_with_date_filter_from_relative_day(): # Assert expected_responses = [ - ('dt="1984-04-01"', ""), - ('dt>="1984-04-01"', 'dt<"1984-04-02"'), - ('dt>"1984-03-31"', 'dt<"1984-04-02"'), + ("dt='1984-04-01'", ""), + ("dt>='1984-04-01'", "dt<'1984-04-02'"), + ("dt>'1984-03-31'", "dt<'1984-04-02'"), ] assert len(response) == 1 assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( @@ -51,7 +51,7 @@ def test_extract_question_with_date_filter_from_relative_month(): response = extract_questions("Which countries did I visit last month?") # Assert - expected_responses = [('dt>="1984-03-01"', 'dt<"1984-04-01"'), ('dt>="1984-03-01"', 'dt<="1984-03-31"')] + expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")] assert len(response) == 1 assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( "Expected date filter to limit to March 1984 in response but got: " + response[0] @@ -67,9 +67,9 @@ def test_extract_question_with_date_filter_from_relative_year(): # Assert expected_responses = [ - ('dt>="1984-01-01"', ""), - ('dt>="1984-01-01"', 'dt<"1985-01-01"'), - ('dt>="1984-01-01"', 'dt<="1984-12-31"'), + ("dt>='1984-01-01'", ""), + ("dt>='1984-01-01'", "dt<'1985-01-01'"), + ("dt>='1984-01-01'", "dt<='1984-12-31'"), ] assert len(response) == 1 assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( @@ -172,8 +172,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history(): # Assert expected_responses = [ - ('dt>="2000-04-01"', 'dt<"2000-05-01"'), - ('dt>="2000-04-01"', 'dt<="2000-04-31"'), + ("dt>='2000-04-01'", "dt<'2000-05-01'"), + ("dt>='2000-04-01'", "dt<='2000-04-30'"), ] assert len(response) == 1 assert "Masai Mara" in response[0]