diff --git a/pyproject.toml b/pyproject.toml
index 2669f5ff..d41d7977 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,7 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 5.0.7",
     "authlib == 1.2.1",
-    "llama-cpp-python == 0.2.76",
+    "llama-cpp-python == 0.2.82",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index edc2d9f0..0979e326 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -74,7 +74,7 @@ def extract_questions_offline(
     state.chat_lock.acquire()
     try:
         response = send_message_to_model_offline(
-            messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+            messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size
         )
     finally:
         state.chat_lock.release()
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 05de4b9f..66017b36 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -24,6 +24,8 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int
     # Add chat format if known
     if "llama-3" in repo_id.lower():
         kwargs["chat_format"] = "llama-3"
+    elif "gemma-2" in repo_id.lower():
+        kwargs["chat_format"] = "gemma"
 
     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5d68d17d..c005dde7 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -254,6 +254,8 @@ def truncate_messages(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
 
+    if system_message:
+        system_message.role = "user" if "gemma-2" in model_name else "system"
     return messages + [system_message] if system_message else messages
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index cbe19891..20a6bc09 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -335,6 +335,7 @@ async def extract_references_and_questions(
         inferred_queries = extract_questions_offline(
             defiltered_query,
+            model=chat_model,
             loaded_model=loaded_model,
             conversation_log=meta_log,
             should_extract_questions=True,
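
For context on the `download_model` hunk: llama-cpp-python selects the chat prompt template via the `chat_format` argument to `Llama`, and it ships a built-in format registered as `"gemma"` that the diff reuses for Gemma 2 models. Below is a minimal sketch of the repo-id-to-chat-format mapping this hunk implements; the Hugging Face repo id and GGUF path in the usage example are hypothetical placeholders, not values from the diff.

```python
from llama_cpp import Llama


def infer_chat_format(repo_id: str) -> str | None:
    """Map a Hugging Face repo id to a llama-cpp-python chat format, if known."""
    repo = repo_id.lower()
    if "llama-3" in repo:
        return "llama-3"
    elif "gemma-2" in repo:
        # llama-cpp-python registers Gemma's prompt template under "gemma"
        return "gemma"
    return None


# Hypothetical usage; the GGUF path stands in for the downloaded model file.
kwargs = {}
if chat_format := infer_chat_format("bartowski/gemma-2-9b-it-GGUF"):
    kwargs["chat_format"] = chat_format
model = Llama(model_path="gemma-2-9b-it-Q4_K_M.gguf", **kwargs)
```

Leaving `chat_format` unset lets llama-cpp-python fall back to the chat template embedded in the GGUF metadata (or its default), so only the formats the code wants to pin explicitly need an entry in this mapping.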
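The `truncate_messages` hunk works around a quirk of Gemma 2's chat template: unlike Llama 3's, it defines no system role, so a message sent with `role="system"` would not render correctly. The diff therefore downgrades the system prompt to a user message for gemma-2 models. A minimal sketch of that remapping, assuming langchain-style `ChatMessage` objects with a mutable `role` attribute (the helper name `finalize_messages` is hypothetical):

```python
from langchain.schema import ChatMessage


def finalize_messages(
    messages: list[ChatMessage], system_message: ChatMessage | None, model_name: str
) -> list[ChatMessage]:
    # Gemma 2's chat template has no system role, so fold the system prompt
    # in as a user message for gemma-2 models; keep role="system" otherwise.
    if system_message:
        system_message.role = "user" if "gemma-2" in model_name else "system"
    return messages + [system_message] if system_message else messages


# Hypothetical usage:
msgs = [ChatMessage(role="user", content="What's new in my notes?")]
sys_msg = ChatMessage(role="system", content="You are Khoj, a personal assistant.")
print(finalize_messages(msgs, sys_msg, "gemma-2-9b-it"))
```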