From c2b7a14ed50d02a99eccae273b467c1c9096f55d Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 1 Aug 2023 19:29:03 -0700
Subject: [PATCH] Fix context, response size for Llama 2 to stay within max token limits

Create regression test to ensure it does not throw the prompt size
exceeded context window error
---
 .../conversation/gpt4all/chat_model.py     |  2 +-
 src/khoj/processor/conversation/utils.py   |  2 +-
 tests/test_gpt4all_chat_actors.py          | 22 +++++++++++++++++++
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 9ca5a1b8..d0c4ff31 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -165,7 +165,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
     templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
     templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
     prompted_message = templated_system_message + chat_history + templated_user_message
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256)
+    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
     state.chat_lock.acquire()
     try:
         for response in response_iterator:
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e414b35b..7bcac2d8 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -14,7 +14,7 @@ import queue
 from khoj.utils.helpers import merge_dicts
 
 logger = logging.getLogger(__name__)
-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
 tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
 
 
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index a7191a66..32e7e941 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -454,6 +454,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
     )
 
 
+# ----------------------------------------------------------------------------------------------------
+def test_chat_does_not_exceed_prompt_size(loaded_model):
+    "Ensure chat context and response together do not exceed max prompt size for the model"
+    # Arrange
+    prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
+    context = [" ".join([f"{number}" for number in range(2043)])]
+
+    # Act
+    response_gen = converse_offline(
+        references=context,  # Assume context retrieved from notes for the user_query
+        user_query="What numbers come after these?",
+        loaded_model=loaded_model,
+    )
+    response = "".join([response_chunk for response_chunk in response_gen])
+
+    # Assert
+    assert prompt_size_exceeded_error not in response, (
+        "Expected chat response to be within prompt limits, but got exceeded error: " + response
+    )
+
+
+# ----------------------------------------------------------------------------------------------------
 def test_filter_questions():
     test_questions = [
         "I don't know how to answer that",
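
Note: a minimal sketch of the token budget the two changed constants encode,
assuming the quantized Llama 2 model is loaded with a 2048 token context
window. The fits_in_context helper below is hypothetical, for illustration
only; it is not part of the patched Khoj code.

    # Illustrative sketch, not part of the patch: the budget this change enforces.
    from transformers import AutoTokenizer

    MODEL_CONTEXT_WINDOW = 2048  # assumed window the GGML model is loaded with
    MAX_PROMPT_SIZE = 1548       # max_prompt_size set in src/khoj/processor/conversation/utils.py
    MAX_RESPONSE_TOKENS = 500    # max_tokens now passed to model.generate() in chat_model.py

    # Invariant the patch restores: prompt and response together fit the window.
    assert MAX_PROMPT_SIZE + MAX_RESPONSE_TOKENS <= MODEL_CONTEXT_WINDOW

    def fits_in_context(prompt: str) -> bool:
        "Check whether a fully templated prompt stays within the prompt budget."
        # Same tokenizer the patched utils.py maps to this model
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        return len(tokenizer.encode(prompt)) <= MAX_PROMPT_SIZE

With the previous values (1000 response tokens against a 2048 prompt budget),
a near-full prompt plus the response could overrun the window, which is what
the regression test above guards against.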