diff --git a/pyproject.toml b/pyproject.toml
index 12c7789c..ed57f55a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,7 +87,7 @@ dependencies = [
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
     "docx2txt == 0.8",
-    "google-generativeai == 0.7.2"
+    "google-generativeai == 0.8.3"
 ]

 dynamic = ["version"]
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 7b848324..ccde33d1 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from threading import Thread

diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index fbde7a98..cbfea6fd 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Optional, Union
@@ -265,8 +266,14 @@ def send_message_to_model_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
+        messages_dict,
+        stop=stop,
+        stream=streaming,
+        temperature=temperature,
+        response_format={"type": response_type},
+        seed=seed,
     )

     if streaming:
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 6e519f5a..36ebc679 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from threading import Thread
 from typing import Dict

@@ -60,6 +61,9 @@ def completion_with_backoff(
     model_kwargs.pop("stop", None)
     model_kwargs.pop("response_format", None)

+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,  # type: ignore
@@ -157,6 +161,9 @@ def llm_thread(
     model_kwargs.pop("stop", None)
     model_kwargs.pop("response_format", None)

+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,
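
Note: each backend gates the new behavior on the same optional environment variable. When KHOJ_LLM_SEED is set to an integer, it is forwarded as the `seed` sampling parameter (both llama-cpp-python's `create_chat_completion` and the OpenAI chat completions API accept a `seed` argument), so repeated runs at low temperature produce reproducible responses. A minimal sketch of the shared pattern; the `llm_seed` helper is illustrative, not part of this patch:

    import os
    from typing import Optional

    def llm_seed() -> Optional[int]:
        # Illustrative helper: read the optional KHOJ_LLM_SEED env var.
        # Returns None when unset, which leaves sampling non-deterministic.
        raw = os.getenv("KHOJ_LLM_SEED")
        return int(raw) if raw else None

    # Usage: export KHOJ_LLM_SEED=42 before starting Khoj to make
    # chat completions repeatable on seed-aware backends.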