Fix context, response size for Llama 2 to stay within max token limits

Create regression text to ensure it does not throw the prompt size
exceeded context window error
This commit is contained in:
Debanjum Singh Solanky 2023-08-01 19:29:03 -07:00
parent 6e4050fa81
commit c2b7a14ed5
3 changed files with 24 additions and 2 deletions

View file

@ -165,7 +165,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content) templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content) templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message prompted_message = templated_system_message + chat_history + templated_user_message
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256) response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
state.chat_lock.acquire() state.chat_lock.acquire()
try: try:
for response in response_iterator: for response in response_iterator:

View file

@ -14,7 +14,7 @@ import queue
from khoj.utils.helpers import merge_dicts from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048} max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"} tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}

View file

@ -454,6 +454,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
) )
# ----------------------------------------------------------------------------------------------------
def test_chat_does_not_exceed_prompt_size(loaded_model):
"Ensure chat context and response together do not exceed max prompt size for the model"
# Arrange
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
context = [" ".join([f"{number}" for number in range(2043)])]
# Act
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What numbers come after these?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert prompt_size_exceeded_error not in response, (
"Expected chat response to be within prompt limits, but got exceeded error: " + response
)
# ----------------------------------------------------------------------------------------------------
def test_filter_questions(): def test_filter_questions():
test_questions = [ test_questions = [
"I don't know how to answer that", "I don't know how to answer that",