mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Fix context, response size for Llama 2 to stay within max token limits
Create regression text to ensure it does not throw the prompt size exceeded context window error
This commit is contained in:
parent
6e4050fa81
commit
c2b7a14ed5
3 changed files with 24 additions and 2 deletions
|
@ -165,7 +165,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000, n_batch=256)
|
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=256)
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
|
|
|
@ -14,7 +14,7 @@ import queue
|
||||||
from khoj.utils.helpers import merge_dicts
|
from khoj.utils.helpers import merge_dicts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
|
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548}
|
||||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -454,6 +454,28 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||||
|
"Ensure chat context and response together do not exceed max prompt size for the model"
|
||||||
|
# Arrange
|
||||||
|
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
|
||||||
|
context = [" ".join([f"{number}" for number in range(2043)])]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response_gen = converse_offline(
|
||||||
|
references=context, # Assume context retrieved from notes for the user_query
|
||||||
|
user_query="What numbers come after these?",
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
)
|
||||||
|
response = "".join([response_chunk for response_chunk in response_gen])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert prompt_size_exceeded_error not in response, (
|
||||||
|
"Expected chat response to be within prompt limits, but got exceeded error: " + response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_filter_questions():
|
def test_filter_questions():
|
||||||
test_questions = [
|
test_questions = [
|
||||||
"I don't know how to answer that",
|
"I don't know how to answer that",
|
||||||
|
|
Loading…
Reference in a new issue