Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 15:38:55 +01:00
Support Gemma 2 for Offline Chat
- Pass the system message as the first user chat message, since Gemma 2 doesn't support system messages
- Use the gemma-2 chat format
- Pass the chat model name to the generic and extract-questions chat actors.
  It is used to figure out which chat template to apply for the model.
  For the generic chat actor, the argument was already available but not being passed, which was confusing
parent 2ab8fb78b1
commit 53eabe0c06

5 changed files with 7 additions and 2 deletions
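For context on the first point: Gemma 2's prompt template only defines user and model turns, so a system role has nowhere to go. Below is a minimal sketch of the idea (not Khoj's code; the template rendering is simplified):

# Sketch: Gemma 2's chat template only knows "user" and "model" turns, so a
# "system" message has to be demoted to a user turn to stay in the prompt.
GEMMA_TURN = "<start_of_turn>{role}\n{content}<end_of_turn>\n"

def render_gemma_prompt(messages: list[dict]) -> str:
    prompt = ""
    for message in messages:
        role = message["role"]
        if role == "system":
            role = "user"  # Gemma 2 has no system role; demote to user
        elif role == "assistant":
            role = "model"  # Gemma names the assistant turn "model"
        prompt += GEMMA_TURN.format(role=role, content=message["content"])
    return prompt + "<start_of_turn>model\n"

print(render_gemma_prompt([
    {"role": "system", "content": "You are Khoj, a helpful assistant."},
    {"role": "user", "content": "Summarize my notes."},
]))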
@@ -66,7 +66,7 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 5.0.7",
     "authlib == 1.2.1",
-    "llama-cpp-python == 0.2.76",
+    "llama-cpp-python == 0.2.82",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
@@ -74,7 +74,7 @@ def extract_questions_offline(
     state.chat_lock.acquire()
     try:
         response = send_message_to_model_offline(
-            messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+            messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size
         )
     finally:
         state.chat_lock.release()
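A hypothetical call site for the updated signature, with a stub standing in for Khoj's real function and an assumed model name:

import threading

# Sketch of the guarded-call pattern in the hunk above: a process-wide lock
# serializes inference on the single loaded llama.cpp model, and the model
# name is now forwarded so the right chat template can be chosen.
chat_lock = threading.Lock()

def send_message_to_model_offline(messages, loaded_model=None, model="", max_prompt_size=None):
    # Stub only: the real function renders `messages` with the chat template
    # implied by `model` and runs inference on `loaded_model`.
    return f"(stub) {len(messages)} message(s) rendered for {model}"

messages = [{"role": "user", "content": "Extract search questions from this chat."}]

chat_lock.acquire()
try:
    response = send_message_to_model_offline(
        messages,
        loaded_model=None,       # the real call passes the loaded llama.cpp model
        model="gemma-2-9b-it",   # assumed name; triggers the gemma-2 handling
        max_prompt_size=4096,
    )
finally:
    chat_lock.release()

print(response)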
@@ -24,6 +24,8 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int
     # Add chat format if known
     if "llama-3" in repo_id.lower():
         kwargs["chat_format"] = "llama-3"
+    elif "gemma-2" in repo_id.lower():
+        kwargs["chat_format"] = "gemma"

     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
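For reference, the chat_format kwarg assembled here is consumed by llama-cpp-python's Llama constructor, which ships a built-in "gemma" chat handler. A sketch with an assumed repo id and model path:

from llama_cpp import Llama

# The repo id and GGUF path below are placeholder examples. Passing
# chat_format="gemma" makes create_chat_completion render Gemma-style
# <start_of_turn> prompts instead of the default format.
repo_id = "bartowski/gemma-2-9b-it-GGUF"  # hypothetical repo id
kwargs = {}
if "llama-3" in repo_id.lower():
    kwargs["chat_format"] = "llama-3"
elif "gemma-2" in repo_id.lower():
    kwargs["chat_format"] = "gemma"

# Assumes the GGUF file has already been downloaded to this path
model = Llama(model_path="gemma-2-9b-it-Q4_K_M.gguf", **kwargs)
response = model.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}]
)
print(response["choices"][0]["message"]["content"])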
@@ -254,6 +254,8 @@ def truncate_messages(
         f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
     )

+    if system_message:
+        system_message.role = "user" if "gemma-2" in model_name else "system"
     return messages + [system_message] if system_message else messages
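The same demotion, isolated as a runnable sketch (ChatMessage is a stand-in for the message class Khoj actually uses):

from dataclasses import dataclass

# Minimal sketch of the role swap added to truncate_messages: the preserved
# system message is re-labelled as a user message for gemma-2 models before
# being appended back to the truncated message list.
@dataclass
class ChatMessage:
    role: str
    content: str

def reattach_system_message(messages, system_message, model_name):
    if system_message:
        system_message.role = "user" if "gemma-2" in model_name else "system"
    return messages + [system_message] if system_message else messages

messages = [ChatMessage("user", "Summarize my notes.")]
system_message = ChatMessage("system", "You are Khoj, a helpful assistant.")
print(reattach_system_message(messages, system_message, "gemma-2-9b-it"))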
@@ -335,6 +335,7 @@ async def extract_references_and_questions(

         inferred_queries = extract_questions_offline(
             defiltered_query,
+            model=chat_model,
             loaded_model=loaded_model,
             conversation_log=meta_log,
             should_extract_questions=True,