Support Gemma 2 for Offline Chat

- Pass system message as the first user chat message as Gemma 2
  doesn't support system messages
- Use gemma-2 chat format
- Pass chat model name to generic, extract questions chat actors
  Used to figure out chat template to use for model
  For generic chat actor argument was anyway available but not being
  passed, which is confusing
This commit is contained in:
Debanjum Singh Solanky 2024-07-09 17:31:30 +05:30
parent 2ab8fb78b1
commit 53eabe0c06
5 changed files with 7 additions and 2 deletions

View file

@ -66,7 +66,7 @@ dependencies = [
"pymupdf >= 1.23.5",
"django == 5.0.7",
"authlib == 1.2.1",
"llama-cpp-python == 0.2.76",
"llama-cpp-python == 0.2.82",
"itsdangerous == 2.1.2",
"httpx == 0.25.0",
"pgvector == 0.2.4",

View file

@ -74,7 +74,7 @@ def extract_questions_offline(
state.chat_lock.acquire()
try:
response = send_message_to_model_offline(
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size
)
finally:
state.chat_lock.release()

View file

@ -24,6 +24,8 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int
# Add chat format if known
if "llama-3" in repo_id.lower():
kwargs["chat_format"] = "llama-3"
elif "gemma-2" in repo_id.lower():
kwargs["chat_format"] = "gemma"
# Check if the model is already downloaded
model_path = load_model_from_cache(repo_id, filename)

View file

@ -254,6 +254,8 @@ def truncate_messages(
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
)
if system_message:
system_message.role = "user" if "gemma-2" in model_name else "system"
return messages + [system_message] if system_message else messages

View file

@ -335,6 +335,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query,
model=chat_model,
loaded_model=loaded_model,
conversation_log=meta_log,
should_extract_questions=True,