From 13b16a4364abf6114056a96f0a52c8e63736e738 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 3 Oct 2023 16:29:46 -0700 Subject: [PATCH 01/12] Use default Llama 2 supported by GPT4All Remove custom logic to download custom Llama 2 model. This was added as GPT4All didn't support Llama 2 when it was added to Khoj --- .../conversation/gpt4all/chat_model.py | 4 +- .../conversation/gpt4all/model_metadata.py | 3 - .../processor/conversation/gpt4all/utils.py | 71 +------------------ src/khoj/processor/conversation/utils.py | 4 +- src/khoj/utils/config.py | 2 +- tests/test_gpt4all_chat_actors.py | 2 +- 6 files changed, 7 insertions(+), 79 deletions(-) delete mode 100644 src/khoj/processor/conversation/gpt4all/model_metadata.py diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index 9bc9ea52..d713831a 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) def extract_questions_offline( text: str, - model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin", + model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", loaded_model: Union[Any, None] = None, conversation_log={}, use_history: bool = True, @@ -123,7 +123,7 @@ def converse_offline( references, user_query, conversation_log={}, - model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin", + model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", loaded_model: Union[Any, None] = None, completion_func=None, conversation_command=ConversationCommand.Default, diff --git a/src/khoj/processor/conversation/gpt4all/model_metadata.py b/src/khoj/processor/conversation/gpt4all/model_metadata.py deleted file mode 100644 index 065e3720..00000000 --- a/src/khoj/processor/conversation/gpt4all/model_metadata.py +++ /dev/null @@ -1,3 +0,0 @@ -model_name_to_url = { - "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin" -} diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index 4042fbe2..585df6a6 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -1,24 +1,8 @@ -import os import logging -import requests -import hashlib -from tqdm import tqdm - -from khoj.processor.conversation.gpt4all import model_metadata logger = logging.getLogger(__name__) -expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"} - - -def get_md5_checksum(filename: str): - hash_md5 = hashlib.md5() - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(8192), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() - def download_model(model_name: str): try: @@ -27,57 +11,4 @@ def download_model(model_name: str): logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - url = model_metadata.model_name_to_url.get(model_name) - model_path = os.path.expanduser(f"~/.cache/gpt4all/") - if not url: - logger.debug(f"Model {model_name} not found in model metadata. Skipping download.") - return GPT4All(model_name=model_name, model_path=model_path) - - filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}") - if os.path.exists(filename): - # Check if the user is connected to the internet - try: - requests.get("https://www.google.com/", timeout=5) - except: - logger.debug("User is offline. 
Disabling allowed download flag") - return GPT4All(model_name=model_name, model_path=model_path, allow_download=False) - return GPT4All(model_name=model_name, model_path=model_path) - - # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file - tmp_filename = filename + ".tmp" - - try: - os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) - logger.debug(f"Downloading model {model_name} from {url} to {filename}...") - with requests.get(url, stream=True) as r: - r.raise_for_status() - total_size = int(r.headers.get("content-length", 0)) - with open(tmp_filename, "wb") as f, tqdm( - unit="B", # unit string to be displayed. - unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc. - unit_divisor=1024, # is used when unit_scale is true - total=total_size, # the total iteration. - desc=model_name, # prefix to be displayed on progress bar. - ) as progress_bar: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - progress_bar.update(len(chunk)) - - # Verify the checksum - if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename): - logger.error( - f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available." - ) - os.remove(tmp_filename) - raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.") - - # Move the tmp file to the actual file - os.rename(tmp_filename, filename) - logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}") - return GPT4All(model_name) - except Exception as e: - logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True) - # Remove the tmp file if it exists - if os.path.exists(tmp_filename): - os.remove(tmp_filename) - return None + return GPT4All(model_name=model_name) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 4a92c367..ece526c2 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -17,10 +17,10 @@ logger = logging.getLogger(__name__) max_prompt_size = { "gpt-3.5-turbo": 4096, "gpt-4": 8192, - "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548, + "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "gpt-3.5-turbo-16k": 15000, } -tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"} +tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"} class ThreadedGenerator: diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index a6532346..f06d4c69 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -84,7 +84,7 @@ class SearchModels: @dataclass class GPT4AllProcessorConfig: - chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin" + chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" loaded_model: Union[Any, None] = None diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index d7904ff8..32ee4020 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model from khoj.processor.conversation.utils import message_to_log -MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin" +MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin" @pytest.fixture(scope="session") From d1ff812021a4c59a5d67495207ad90a0fe0be44d Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 4 Oct 2023 18:42:12 -0700 Subject: 
[PATCH 02/12] Run GPT4All Chat Model on GPU, when available GPT4All now supports running models on GPU via Vulkan --- src/khoj/processor/conversation/gpt4all/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index 585df6a6..d5201780 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -11,4 +11,12 @@ def download_model(model_name: str): logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - return GPT4All(model_name=model_name) + # Use GPU for Chat Model, if available + try: + model = GPT4All(model_name=model_name, device="gpu") + logger.debug("Loaded chat model to GPU.") + except ValueError: + model = GPT4All(model_name=model_name) + logger.debug("Loaded chat model to CPU.") + + return model From a85ff941ca49538ac6090e4d891e72710737744f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 4 Oct 2023 20:39:31 -0700 Subject: [PATCH 03/12] Make offline chat model user configurable Only GPT4All supported Llama v2 models will work given the prompt structure is not currently configurable --- src/khoj/utils/config.py | 3 ++- src/khoj/utils/rawconfig.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index f06d4c69..5accd2ad 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -84,7 +84,7 @@ class SearchModels: @dataclass class GPT4AllProcessorConfig: - chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" + chat_model: Optional[str] = None loaded_model: Union[Any, None] = None @@ -95,6 +95,7 @@ class ConversationProcessorConfigModel: ): self.openai_model = conversation_config.openai self.gpt4all_model = GPT4AllProcessorConfig() + self.gpt4all_model.chat_model = conversation_config.offline_chat_model self.enable_offline_chat = conversation_config.enable_offline_chat self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 0a916db4..30a98354 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -95,6 +95,7 @@ class ConversationProcessorConfig(ConfigBase): conversation_logfile: Path openai: Optional[OpenAIProcessorConfig] enable_offline_chat: Optional[bool] = False + offline_chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" class ProcessorConfig(ConfigBase): From 56bd69d5af036a09223bd1c3b596fe83443401ef Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 4 Oct 2023 20:42:25 -0700 Subject: [PATCH 04/12] Improve Llama v2 extract questions actor and associated prompt - Format extract questions prompt format with newlines and whitespaces - Make llama v2 extract questions prompt consistent - Remove empty questions extracted by offline extract_questions actor - Update implicit qs extraction unit test for offline search actor --- .../conversation/gpt4all/chat_model.py | 2 +- src/khoj/processor/conversation/prompts.py | 38 +++++++++++-------- tests/test_gpt4all_chat_actors.py | 6 +-- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index d713831a..e9beaa80 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ 
b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -113,7 +113,7 @@ def filter_questions(questions: List[str]): ] filtered_questions = [] for q in questions: - if not any([word in q.lower() for word in hint_words]): + if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q): filtered_questions.append(q) return filtered_questions diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 4de3c623..d487609d 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template( """.strip() ) -system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant. +system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant. Using your general knowledge and our past conversations as context, answer the following question. If you do not know the answer, say 'I don't know.'""" @@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template( general_conversation_llamav2 = PromptTemplate.from_template( """ -[INST]{query}[/INST] +[INST] {query} [/INST] """.strip() ) chat_history_llamav2_from_user = PromptTemplate.from_template( """ -[INST]{message}[/INST] +[INST] {message} [/INST] """.strip() ) @@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template( conversation_llamav2 = PromptTemplate.from_template( """ -[INST]{query}[/INST] +[INST] {query} [/INST] """.strip() ) @@ -91,7 +91,7 @@ Question: {query} notes_conversation_llamav2 = PromptTemplate.from_template( """ -Notes: +User's Notes: {references} Question: {query} """.strip() @@ -134,19 +134,25 @@ Answer (in second person):""" extract_questions_llamav2_sample = PromptTemplate.from_template( """ -[INST]<>Current Date: {current_date}<>[/INST] -[INST]How was my trip to Cambodia?[/INST][] -[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia? -[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have? -[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic? -[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}' -[INST]How are you feeling today?[/INST] -[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age? -[INST]<> +[INST] <>Current Date: {current_date}<> [/INST] +[INST] How was my trip to Cambodia? [/INST] +How was my trip to Cambodia? +[INST] Who did I visit the temple with on that trip? [/INST] +Who did I visit the temple with in Cambodia? +[INST] How should I take care of my plants? [/INST] +What kind of plants do I have? What issues do my plants have? +[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST] +What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic? +[INST] What did I do for Christmas last year? [/INST] +What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}' +[INST] How are you feeling today? [/INST] +[INST] Is Alice older than Bob? [/INST] +When was Alice born? What is Bob's age? 
+[INST] <> Use these notes from the user's previous conversations to provide a response: {chat_history} -<>[/INST] -[INST]{query}[/INST] +<> [/INST] +[INST] {query} [/INST] """ ) diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index 32ee4020..056618be 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model): @pytest.mark.chatquality def test_extract_multiple_implicit_questions_from_message(loaded_model): # Act - response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model) + response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model) # Assert - expected_responses = ["height", "taller", "shorter", "heights"] + expected_responses = ["height", "taller", "shorter", "heights", "who"] assert len(response) <= 3 for question in response: assert any([expected_response in question.lower() for expected_response in expected_responses]), ( - "Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question + "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question ) From 1ad8b150e88061d5cea295b610be2185c8532047 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 13 Oct 2023 22:26:59 -0700 Subject: [PATCH 05/12] Add default tokenizer, max_prompt as fallback for non-default offline chat models Pass user configured chat model as argument to use by converse_offline The proper fix for this would allow users to configure the max_prompt and tokenizer to use (while supplying default ones, if none provided) For now, this is a reasonable start. --- pyproject.toml | 4 ++-- src/khoj/processor/conversation/utils.py | 12 +++++++++--- src/khoj/routers/helpers.py | 1 + 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a52fc9b6..e6773b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,8 +59,8 @@ dependencies = [ "bs4 >= 0.0.1", "anyio == 3.7.1", "pymupdf >= 1.23.3", - "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", - "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", + "gpt4all >= 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", + "gpt4all >= 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", ] dynamic = ["version"] diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index ece526c2..96c4c1c8 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -19,8 +19,12 @@ max_prompt_size = { "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "gpt-3.5-turbo-16k": 15000, + "default": 1600, +} +tokenizer = { + "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer", + "default": "hf-internal-testing/llama-tokenizer", } -tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"} class ThreadedGenerator: @@ -105,7 +109,7 @@ def generate_chatml_messages_with_context( messages = user_chatml_message + rest_backnforths + system_chatml_message # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size[model_name], model_name) + messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name) # Return 
message in chronological order return messages[::-1] @@ -116,8 +120,10 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) if "llama" in model_name: encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name]) - else: + elif "gpt" in model_name: encoder = tiktoken.encoding_for_model(model_name) + else: + encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"]) system_message = messages.pop() system_message_tokens = len(encoder.encode(system_message.content)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 267af330..3898d1b8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -122,6 +122,7 @@ def generate_chat_response( conversation_log=meta_log, completion_func=partial_completion, conversation_command=conversation_command, + model=state.processor_config.conversation.gpt4all_model.chat_model, ) elif state.processor_config.conversation.openai_model: From 247e75595c3377529497597dbd4a0fe4ef6cb0a3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 14 Oct 2023 16:54:52 -0700 Subject: [PATCH 06/12] Use AutoTokenizer to support more tokenizers --- src/khoj/processor/conversation/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 96c4c1c8..7bb86887 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -7,7 +7,7 @@ import tiktoken # External packages from langchain.schema import ChatMessage -from transformers import LlamaTokenizerFast +from transformers import AutoTokenizer # Internal Packages import queue @@ -115,15 +115,13 @@ def generate_chatml_messages_with_context( return messages[::-1] -def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]: +def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - if "llama" in model_name: - encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name]) - elif "gpt" in model_name: + if model_name.startswith("gpt-"): encoder = tiktoken.encoding_for_model(model_name) else: - encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"]) + encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"])) system_message = messages.pop() system_message_tokens = len(encoder.encode(system_message.content)) From feb4f17e3d3e8aaabcf5a41c3be4f9d1914ec5b8 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 15 Oct 2023 14:19:29 -0700 Subject: [PATCH 07/12] Update chat config schema. 
Make max_prompt, chat tokenizer configurable This provides flexibility to use non 1st party supported chat models - Create migration script to update khoj.yml config - Put `enable_offline_chat' under new `offline-chat' section Referring code needs to be updated to accomodate this change - Move `offline_chat_model' to `chat-model' under new `offline-chat' section - Put chat `tokenizer` under new `offline-chat' section - Put `max_prompt' under existing `conversation' section As `max_prompt' size effects both openai and offline chat models --- src/khoj/configure.py | 6 +- src/khoj/interface/web/config.html | 14 ++-- .../migrations/migrate_offline_chat_schema.py | 83 +++++++++++++++++++ src/khoj/routers/api.py | 10 +-- src/khoj/routers/helpers.py | 2 +- src/khoj/utils/cli.py | 8 +- src/khoj/utils/config.py | 6 +- src/khoj/utils/rawconfig.py | 10 ++- tests/conftest.py | 4 +- 9 files changed, 119 insertions(+), 24 deletions(-) create mode 100644 src/khoj/migrations/migrate_offline_chat_schema.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 7e6cc409..769f015c 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -19,7 +19,7 @@ from khoj.utils.config import ( ) from khoj.utils.helpers import resolve_absolute_path, merge_dicts from khoj.utils.fs_syncer import collect_files -from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig +from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig from khoj.routers.indexer import configure_content, load_content, configure_search @@ -168,9 +168,7 @@ def configure_conversation_processor( conversation_config=ConversationProcessorConfig( conversation_logfile=conversation_logfile, openai=(conversation_config.openai if (conversation_config is not None) else None), - enable_offline_chat=( - conversation_config.enable_offline_chat if (conversation_config is not None) else False - ), + offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(), ) ) else: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 3b295a88..d41ca26b 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -236,7 +236,7 @@
-                        Setup chat using OpenAI
+                        Setup online chat using OpenAI
...
-                        Setup offline chat (Llama V2)
+                        Setup offline chat
...
@@ -346,7 +346,7 @@ featuresHintText.classList.add("show"); } - fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, { + fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/khoj/migrations/migrate_offline_chat_schema.py b/src/khoj/migrations/migrate_offline_chat_schema.py new file mode 100644 index 00000000..873783a3 --- /dev/null +++ b/src/khoj/migrations/migrate_offline_chat_schema.py @@ -0,0 +1,83 @@ +""" +Current format of khoj.yml +--- +app: + ... +content-type: + ... +processor: + conversation: + enable-offline-chat: false + conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json + openai: + ... +search-type: + ... + +New format of khoj.yml +--- +app: + ... +content-type: + ... +processor: + conversation: + offline-chat: + enable-offline-chat: false + chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin + tokenizer: null + max_prompt_size: null + conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json + openai: + ... +search-type: + ... +""" +import logging +from packaging import version + +from khoj.utils.yaml import load_config_from_file, save_config_to_file + + +logger = logging.getLogger(__name__) + + +def migrate_offline_chat_schema(args): + schema_version = "0.12.3" + raw_config = load_config_from_file(args.config_file) + previous_version = raw_config.get("version") + + if "processor" not in raw_config: + return args + if raw_config["processor"] is None: + return args + if "conversation" not in raw_config["processor"]: + return args + + if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"): + logger.info( + f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration" + ) + raw_config["version"] = schema_version + + # Create max-prompt-size field in conversation processor schema + raw_config["processor"]["conversation"]["max-prompt-size"] = None + raw_config["processor"]["conversation"]["tokenizer"] = None + + # Create offline chat schema based on existing enable_offline_chat field in khoj config schema + offline_chat_model = ( + raw_config["processor"]["conversation"] + .get("offline-chat", {}) + .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin") + ) + raw_config["processor"]["conversation"]["offline-chat"] = { + "enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False), + "chat-model": offline_chat_model, + } + + # Delete old enable-offline-chat field from conversation processor schema + if "enable-offline-chat" in raw_config["processor"]["conversation"]: + del raw_config["processor"]["conversation"]["enable-offline-chat"] + + save_config_to_file(raw_config, args.config_file) + return args diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 2ff6bab0..91db7c58 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -284,7 +284,7 @@ if not state.demo: except Exception as e: return {"status": "error", "message": str(e)} - @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200) + @api.post("/config/data/processor/conversation/offline_chat", status_code=200) async def set_processor_enable_offline_chat_config_data( request: Request, enable_offline_chat: bool, @@ -301,7 +301,7 @@ if not state.demo: state.config.processor = 
ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore assert state.config.processor.conversation is not None - state.config.processor.conversation.enable_offline_chat = enable_offline_chat + state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat state.processor_config = configure_processor(state.config.processor, state.processor_config) update_telemetry_state( @@ -707,7 +707,7 @@ async def chat( ) conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references)) if conversation_command == ConversationCommand.Help: - model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai" + model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai" formatted_help = help_message.format(model=model_type, version=state.khoj_version) return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) @@ -784,7 +784,7 @@ async def extract_references_and_questions( # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - if state.processor_config.conversation.enable_offline_chat: + if state.processor_config.conversation.offline_chat.enable_offline_chat: loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False @@ -800,7 +800,7 @@ async def extract_references_and_questions( with timer("Searching knowledge base took", logger): result_list = [] for query in inferred_queries: - n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n + n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n result_list.extend( await search( f"{query} {filters_in_query}", diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 3898d1b8..0bc66991 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -113,7 +113,7 @@ def generate_chat_response( meta_log=meta_log, ) - if state.processor_config.conversation.enable_offline_chat: + if state.processor_config.conversation.offline_chat.enable_offline_chat: loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model chat_response = converse_offline( references=compiled_references, diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index 78a9ccf9..1d6106cb 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file from khoj.migrations.migrate_version import migrate_config_to_version from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema from khoj.migrations.migrate_offline_model import migrate_offline_model +from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema def cli(args=None): @@ -55,7 +56,12 @@ def cli(args=None): def run_migrations(args): - migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model] + migrations = [ + migrate_config_to_version, + migrate_processor_conversation_schema, + migrate_offline_model, + migrate_offline_chat_schema, + ] for migration in migrations: args = migration(args) return 
args diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 5accd2ad..90e8862a 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -96,18 +96,18 @@ class ConversationProcessorConfigModel: self.openai_model = conversation_config.openai self.gpt4all_model = GPT4AllProcessorConfig() self.gpt4all_model.chat_model = conversation_config.offline_chat_model - self.enable_offline_chat = conversation_config.enable_offline_chat + self.offline_chat = conversation_config.offline_chat self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] self.meta_log: dict = {} - if self.enable_offline_chat: + if self.offline_chat.enable_offline_chat: try: self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) except ValueError as e: + self.offline_chat.enable_offline_chat = False self.gpt4all_model.loaded_model = None logger.error(f"Error while loading offline chat model: {e}", exc_info=True) - self.enable_offline_chat = False else: self.gpt4all_model.loaded_model = None diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 30a98354..f7c42266 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -91,11 +91,17 @@ class OpenAIProcessorConfig(ConfigBase): chat_model: Optional[str] = "gpt-3.5-turbo" +class OfflineChatProcessorConfig(ConfigBase): + enable_offline_chat: Optional[bool] = False + chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" + + class ConversationProcessorConfig(ConfigBase): conversation_logfile: Path openai: Optional[OpenAIProcessorConfig] - enable_offline_chat: Optional[bool] = False - offline_chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" + offline_chat: Optional[OfflineChatProcessorConfig] + max_prompt_size: Optional[int] + tokenizer: Optional[str] class ProcessorConfig(ConfigBase): diff --git a/tests/conftest.py b/tests/conftest.py index d851341d..f75dfceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, ConversationProcessorConfig, + OfflineChatProcessorConfig, OpenAIProcessorConfig, ProcessorConfig, TextContentConfig, @@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory): # Setup conversation processor processor_config = ProcessorConfig() + offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True) processor_config.conversation = ConversationProcessorConfig( - enable_offline_chat=True, + offline_chat=offline_chat, conversation_logfile=processor_dir.joinpath("conversation_logs.json"), ) From 116595b351d1dfeeaaa7399d25cbb32c064eeafa Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 15 Oct 2023 14:24:28 -0700 Subject: [PATCH 08/12] Use chat_model specified in new offline_chat section of config - Dedupe offline_chat_model variable. Only reference offline chat model stored under offline_chat. 
Delete the previous chat_model field under GPT4AllProcessorConfig - Set offline chat model to use via config/offline_chat API endpoint --- src/khoj/routers/api.py | 3 +++ src/khoj/routers/helpers.py | 2 +- src/khoj/utils/config.py | 4 +--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 91db7c58..8dc0a37e 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -288,6 +288,7 @@ if not state.demo: async def set_processor_enable_offline_chat_config_data( request: Request, enable_offline_chat: bool, + offline_chat_model: Optional[str] = None, client: Optional[str] = None, ): _initialize_config() @@ -302,6 +303,8 @@ if not state.demo: assert state.config.processor.conversation is not None state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat + if offline_chat_model is not None: + state.config.processor.conversation.offline_chat.chat_model = offline_chat_model state.processor_config = configure_processor(state.config.processor, state.processor_config) update_telemetry_state( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0bc66991..d8b0aa8b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -122,7 +122,7 @@ def generate_chat_response( conversation_log=meta_log, completion_func=partial_completion, conversation_command=conversation_command, - model=state.processor_config.conversation.gpt4all_model.chat_model, + model=state.processor_config.conversation.offline_chat.chat_model, ) elif state.processor_config.conversation.openai_model: diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 90e8862a..daae1982 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -84,7 +84,6 @@ class SearchModels: @dataclass class GPT4AllProcessorConfig: - chat_model: Optional[str] = None loaded_model: Union[Any, None] = None @@ -95,7 +94,6 @@ class ConversationProcessorConfigModel: ): self.openai_model = conversation_config.openai self.gpt4all_model = GPT4AllProcessorConfig() - self.gpt4all_model.chat_model = conversation_config.offline_chat_model self.offline_chat = conversation_config.offline_chat self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] @@ -103,7 +101,7 @@ class ConversationProcessorConfigModel: if self.offline_chat.enable_offline_chat: try: - self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) + self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model) except ValueError as e: self.offline_chat.enable_offline_chat = False self.gpt4all_model.loaded_model = None From df1d74a879d5b62ab983bcbba8d9bee1c5fce03f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 15 Oct 2023 16:33:26 -0700 Subject: [PATCH 09/12] Use max_prompt_size, tokenizer from config for chat model context stuffing --- .../conversation/gpt4all/chat_model.py | 4 ++ src/khoj/processor/conversation/openai/gpt.py | 4 ++ src/khoj/processor/conversation/utils.py | 45 ++++++++++++++----- src/khoj/routers/helpers.py | 4 ++ src/khoj/utils/config.py | 2 + 5 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py index e9beaa80..7e92d002 100644 --- a/src/khoj/processor/conversation/gpt4all/chat_model.py +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -127,6 +127,8 @@ def converse_offline( 
loaded_model: Union[Any, None] = None, completion_func=None, conversation_command=ConversationCommand.Default, + max_prompt_size=None, + tokenizer_name=None, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama @@ -158,6 +160,8 @@ def converse_offline( prompts.system_prompt_message_llamav2, conversation_log, model_name=model, + max_prompt_size=max_prompt_size, + tokenizer_name=tokenizer_name, ) g = ThreadedGenerator(references, completion_func=completion_func) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 96510586..73b4f176 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -116,6 +116,8 @@ def converse( temperature: float = 0.2, completion_func=None, conversation_command=ConversationCommand.Default, + max_prompt_size=None, + tokenizer_name=None, ): """ Converse with user using OpenAI's ChatGPT @@ -141,6 +143,8 @@ def converse( prompts.personality.format(), conversation_log, model, + max_prompt_size, + tokenizer_name, ) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) logger.debug(f"Conversation Context for GPT: {truncated_messages}") diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 7bb86887..5f219b83 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -13,17 +13,16 @@ from transformers import AutoTokenizer import queue from khoj.utils.helpers import merge_dicts + logger = logging.getLogger(__name__) -max_prompt_size = { +model_to_prompt_size = { "gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "gpt-3.5-turbo-16k": 15000, - "default": 1600, } -tokenizer = { +model_to_tokenizer = { "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer", - "default": "hf-internal-testing/llama-tokenizer", } @@ -86,7 +85,13 @@ def message_to_log( def generate_chatml_messages_with_context( - user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2 + user_message, + system_message, + conversation_log={}, + model_name="gpt-3.5-turbo", + lookback_turns=2, + max_prompt_size=None, + tokenizer_name=None, ): """Generate messages for ChatGPT with context from previous conversation""" # Extract Chat History for Context @@ -108,20 +113,38 @@ def generate_chatml_messages_with_context( messages = user_chatml_message + rest_backnforths + system_chatml_message + # Set max prompt size from user config, pre-configured for model or to default prompt size + try: + max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] + except: + max_prompt_size = 2000 + logger.warning( + f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window." 
+ ) + # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name) + messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) # Return message in chronological order return messages[::-1] -def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]: +def truncate_messages( + messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None +) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - if model_name.startswith("gpt-"): - encoder = tiktoken.encoding_for_model(model_name) - else: - encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"])) + try: + if model_name.startswith("gpt-"): + encoder = tiktoken.encoding_for_model(model_name) + else: + encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) + except: + default_tokenizer = "hf-internal-testing/llama-tokenizer" + encoder = AutoTokenizer.from_pretrained(default_tokenizer) + logger.warning( + f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." + ) system_message = messages.pop() system_message_tokens = len(encoder.encode(system_message.content)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index d8b0aa8b..6b42f29c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -123,6 +123,8 @@ def generate_chat_response( completion_func=partial_completion, conversation_command=conversation_command, model=state.processor_config.conversation.offline_chat.chat_model, + max_prompt_size=state.processor_config.conversation.max_prompt_size, + tokenizer_name=state.processor_config.conversation.tokenizer, ) elif state.processor_config.conversation.openai_model: @@ -136,6 +138,8 @@ def generate_chat_response( api_key=api_key, completion_func=partial_completion, conversation_command=conversation_command, + max_prompt_size=state.processor_config.conversation.max_prompt_size, + tokenizer_name=state.processor_config.conversation.tokenizer, ) except Exception as e: diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index daae1982..3930ec98 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -95,6 +95,8 @@ class ConversationProcessorConfigModel: self.openai_model = conversation_config.openai self.gpt4all_model = GPT4AllProcessorConfig() self.offline_chat = conversation_config.offline_chat + self.max_prompt_size = conversation_config.max_prompt_size + self.tokenizer = conversation_config.tokenizer self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] self.meta_log: dict = {} From 1a9023d3968e9e7ae079dbcf6ee0105209f8d621 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 15 Oct 2023 17:22:44 -0700 Subject: [PATCH 10/12] Update Chat Actor test to not incept with prior world knowledge --- tests/test_gpt4all_chat_actors.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index 056618be..76ed26e7 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -145,7 +145,7 @@ def 
test_extract_multiple_implicit_questions_from_message(loaded_model): def test_generate_search_query_using_question_from_chat_history(loaded_model): # Arrange message_list = [ - ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), + ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ] # Act @@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): use_history=True, ) - expected_responses = [ - "Vader", - "sons", + all_expected_in_response = [ + "Anderson", + ] + + any_expected_in_response = [ "son", - "Darth", + "sons", "children", ] # Assert assert len(response) >= 1 - assert any([expected_response in response[0] for expected_response in expected_responses]), ( + assert all([expected_response in response[0] for expected_response in all_expected_in_response]), ( + "Expected chat actor to ask for clarification in response, but got: " + response[0] + ) + assert any([expected_response in response[0] for expected_response in any_expected_in_response]), ( "Expected chat actor to ask for clarification in response, but got: " + response[0] ) @@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Arrange message_list = [ - ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), + ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ] # Act response = extract_questions_offline( - "Is she a Jedi?", + "Is she a Doctor?", conversation_log=populate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) expected_responses = [ - "Leia", - "Vader", + "Barbara", + "Robert", "daughter", ] From 90e1d9e3d685f4f6c54835f5092c88c6a252b61e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 16 Oct 2023 10:57:16 -0700 Subject: [PATCH 11/12] Pin gpt4all to 1.0.12 as next version will introduce breaking changes --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6773b88..a52fc9b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,8 +59,8 @@ dependencies = [ "bs4 >= 0.0.1", "anyio == 3.7.1", "pymupdf >= 1.23.3", - "gpt4all >= 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", - "gpt4all >= 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", + "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", + "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", ] dynamic = ["version"] From 644c3b787f12bbc2d3f4814bd4afc5fd82c9e099 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 16 Oct 2023 11:15:38 -0700 Subject: [PATCH 12/12] Scale no. of chat history messages to use as context with max_prompt_size Previously lookback turns was set to a static 2. But now that we support more chat models, their prompt size vary considerably. Make lookback_turns proportional to max_prompt_size. 
The truncate_messages can remove messages if they exceed max_prompt_size later This lets Khoj pass more of the chat history as context for models with larger context window --- src/khoj/processor/conversation/utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 5f219b83..83d51f2d 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -3,6 +3,7 @@ import logging from time import perf_counter import json from datetime import datetime +import queue import tiktoken # External packages @@ -10,7 +11,6 @@ from langchain.schema import ChatMessage from transformers import AutoTokenizer # Internal Packages -import queue from khoj.utils.helpers import merge_dicts @@ -89,11 +89,22 @@ def generate_chatml_messages_with_context( system_message, conversation_log={}, model_name="gpt-3.5-turbo", - lookback_turns=2, max_prompt_size=None, tokenizer_name=None, ): """Generate messages for ChatGPT with context from previous conversation""" + # Set max prompt size from user config, pre-configured for model or to default prompt size + try: + max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] + except: + max_prompt_size = 2000 + logger.warning( + f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window." + ) + + # Scale lookback turns proportional to max prompt size supported by model + lookback_turns = max_prompt_size // 750 + # Extract Chat History for Context chat_logs = [] for chat in conversation_log.get("chat", []): @@ -113,15 +124,6 @@ def generate_chatml_messages_with_context( messages = user_chatml_message + rest_backnforths + system_chatml_message - # Set max prompt size from user config, pre-configured for model or to default prompt size - try: - max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] - except: - max_prompt_size = 2000 - logger.warning( - f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window." - ) - # Truncate oldest messages from conversation history until under max supported prompt size by model messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
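Taken together, these patches make the chat prompt size and tokenizer user configurable, with fallbacks for chat models that have no pre-configured entry. The sketch below is illustrative only and not part of the patch series: it shows how the updated generate_chatml_messages_with_context is expected to resolve those fallbacks. The model filename and query strings are hypothetical; the fallback values (a 2000 token prompt budget and the hf-internal-testing/llama-tokenizer tokenizer) are the defaults introduced in the diffs above.

    # Illustrative sketch: exercise the fallback paths for an unconfigured chat model.
    # "my-custom-chat-model.bin" is a hypothetical model name with no entry in
    # model_to_prompt_size or model_to_tokenizer.
    from khoj.processor.conversation.utils import generate_chatml_messages_with_context

    messages = generate_chatml_messages_with_context(
        user_message="What did I plan for the weekend?",
        system_message="You are Khoj, a smart, inquisitive and helpful personal assistant.",
        conversation_log={},                    # no prior chat history to stuff into context
        model_name="my-custom-chat-model.bin",  # unknown model, so the defaults apply
        max_prompt_size=None,                   # falls back to the 2000 token default
        tokenizer_name=None,                    # falls back to hf-internal-testing/llama-tokenizer
    )
    # Messages come back in chronological order, truncated to fit the resolved prompt budget.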