Fix context and response size for Llama 2 to stay within its max token limits
Add a regression test to ensure chat does not throw the "prompt size exceeded context window" error
This commit is contained in:
parent 6e4050fa81
commit c2b7a14ed5

3 changed files with 24 additions and 2 deletions
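The numbers in this change are picked so that the prompt budget and the maximum response length together fit inside the context window of the offline Llama 2 model. A quick sanity check of that budget, assuming the loaded GPT4All build of llama-2-7b-chat exposes a 2048-token context window (as the previous max_prompt_size of 2048 suggests):

    # Token budget for llama-2-7b-chat.ggmlv3.q4_K_S.bin (assumed 2048-token context window)
    context_window = 2048
    max_prompt_size = 1548      # tokens reserved for system prompt, chat history and references
    max_response_tokens = 500   # max_tokens passed to model.generate() below
    assert max_prompt_size + max_response_tokens <= context_window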
@@ -165,7 +165,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
     templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
     templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
     prompted_message = templated_system_message + chat_history + templated_user_message
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256)
+    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
     state.chat_lock.acquire()
     try:
         for response in response_iterator:

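A minimal, self-contained sketch of the streaming call being tuned above, using the gpt4all Python package; the prompt string is a placeholder and the model is downloaded on first use:

    from gpt4all import GPT4All

    # Sketch of the streaming generation call with the reduced response budget.
    model = GPT4All("llama-2-7b-chat.ggmlv3.q4_K_S.bin")
    response_iterator = model.generate(
        "What is the capital of France?", streaming=True, max_tokens=500, n_batch=256
    )
    print("".join(response_iterator))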
@@ -14,7 +14,7 @@ import queue
 from khoj.utils.helpers import merge_dicts

 logger = logging.getLogger(__name__)

-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
 tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}

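An illustrative sketch of how these two maps could be consumed to keep a prompt under budget; the use of Hugging Face transformers here is an assumption for demonstration and is not the repository's actual truncation logic:

    from transformers import AutoTokenizer

    # Values taken from the diff above; redeclared so the sketch is self-contained.
    model_name = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
    max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, model_name: 1548}
    tokenizer = {model_name: "hf-internal-testing/llama-tokenizer"}

    # Count prompt tokens with the tokenizer registered for the offline model
    # and compare against its prompt budget.
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer[model_name])
    prompt = " ".join(str(number) for number in range(2043))  # same oversized context as the regression test below
    if len(hf_tokenizer.encode(prompt)) > max_prompt_size[model_name]:
        print("Prompt exceeds the token budget; chat history/context should be truncated before generation")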
@@ -454,6 +454,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
     )


+# ----------------------------------------------------------------------------------------------------
+def test_chat_does_not_exceed_prompt_size(loaded_model):
+    "Ensure chat context and response together do not exceed max prompt size for the model"
+    # Arrange
+    prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
+    context = [" ".join([f"{number}" for number in range(2043)])]
+
+    # Act
+    response_gen = converse_offline(
+        references=context,  # Assume context retrieved from notes for the user_query
+        user_query="What numbers come after these?",
+        loaded_model=loaded_model,
+    )
+    response = "".join([response_chunk for response_chunk in response_gen])
+
+    # Assert
+    assert prompt_size_exceeded_error not in response, (
+        "Expected chat response to be within prompt limits, but got exceeded error: " + response
+    )
+
+
 # ----------------------------------------------------------------------------------------------------
 def test_filter_questions():
     test_questions = [
         "I don't know how to answer that",
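The new test feeds converse_offline an oversized references list (2043 space-separated numbers, which tokenize to well over the 1548-token prompt budget) and asserts the model still answers instead of surfacing the context-window error, i.e. that the prompt gets trimmed rather than rejected. Assuming pytest is the project's test runner, it can be run in isolation with pytest -k test_chat_does_not_exceed_prompt_size; the loaded_model fixture is expected to provide the offline Llama 2 model, which may be downloaded on first use.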