From 53eabe0c06bda4d8086baee0285d773a505531d6 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 9 Jul 2024 17:31:30 +0530
Subject: [PATCH] Support Gemma 2 for Offline Chat

- Pass the system message as the first user chat message, as Gemma 2
  doesn't support system messages
- Use the gemma-2 chat format
- Pass the chat model name to the generic and extract-questions chat
  actors. It is used to figure out which chat template to apply for
  the model. The generic chat actor already accepted this argument,
  but it was not being passed, which was confusing
---
 pyproject.toml                                         | 2 +-
 src/khoj/processor/conversation/offline/chat_model.py | 2 +-
 src/khoj/processor/conversation/offline/utils.py      | 2 ++
 src/khoj/processor/conversation/utils.py               | 2 ++
 src/khoj/routers/api.py                                | 1 +
 5 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2669f5ff..d41d7977 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,7 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 5.0.7",
     "authlib == 1.2.1",
-    "llama-cpp-python == 0.2.76",
+    "llama-cpp-python == 0.2.82",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index edc2d9f0..0979e326 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -74,7 +74,7 @@ def extract_questions_offline(
     state.chat_lock.acquire()
     try:
         response = send_message_to_model_offline(
-            messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+            messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size
         )
     finally:
         state.chat_lock.release()
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 05de4b9f..66017b36 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -24,6 +24,8 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int
     # Add chat format if known
     if "llama-3" in repo_id.lower():
         kwargs["chat_format"] = "llama-3"
+    elif "gemma-2" in repo_id.lower():
+        kwargs["chat_format"] = "gemma"

     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5d68d17d..c005dde7 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -254,6 +254,8 @@ def truncate_messages(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )

+    if system_message:
+        system_message.role = "user" if "gemma-2" in model_name else "system"
     return messages + [system_message] if system_message else messages
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index cbe19891..20a6bc09 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -335,6 +335,7 @@ async def extract_references_and_questions(
         inferred_queries = extract_questions_offline(
             defiltered_query,
+            model=chat_model,
             loaded_model=loaded_model,
             conversation_log=meta_log,
             should_extract_questions=True,
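
Note, not part of the patch: a minimal runnable sketch of the role remapping
added to truncate_messages above. ChatMessage and finalize_messages are
stand-in names for illustration; khoj's actual messages are langchain
ChatMessage objects with a mutable role attribute. Per the commit message,
khoj's downstream formatting ends up placing this demoted system prompt at
the start of the conversation.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ChatMessage:
    role: str
    content: str

def finalize_messages(
    messages: List[ChatMessage],
    system_message: Optional[ChatMessage],
    model_name: str,
) -> List[ChatMessage]:
    # Gemma 2 has no system role, so demote the system prompt to a user
    # message; every other model keeps the standard system role.
    if system_message:
        system_message.role = "user" if "gemma-2" in model_name else "system"
    return messages + [system_message] if system_message else messages

chat = [ChatMessage(role="user", content="What is Khoj?")]
system = ChatMessage(role="system", content="You are Khoj, a helpful assistant.")

print([m.role for m in finalize_messages(chat, system, "gemma-2-9b-it")])
# -> ['user', 'user']: the system prompt is now a plain user message
print([m.role for m in finalize_messages(chat, system, "llama-3-8b-instruct")])
# -> ['user', 'system']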
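
Note, not part of the patch: a sketch of how the chat_format chosen in
download_model reaches llama-cpp-python. Llama.from_pretrained and the
"llama-3"/"gemma" chat format handlers are existing llama-cpp-python APIs;
the repo id below is an illustrative example, not one this patch pins.
llama-cpp-python's gemma chat handler ignores system-role messages, which
is presumably why the system prompt is demoted to a user message above.

from llama_cpp import Llama

repo_id = "bartowski/gemma-2-9b-it-GGUF"  # hypothetical example repo
kwargs = {"n_ctx": 0, "verbose": False}  # n_ctx=0: use the model's own context size

# Same chat-format selection as the download_model hunk above
if "llama-3" in repo_id.lower():
    kwargs["chat_format"] = "llama-3"
elif "gemma-2" in repo_id.lower():
    kwargs["chat_format"] = "gemma"

# Downloads the GGUF weights from Hugging Face on first use
model = Llama.from_pretrained(repo_id=repo_id, filename="*Q4_K_M.gguf", **kwargs)
response = model.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one sentence."}]
)
print(response["choices"][0]["message"]["content"])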