Support setting seed for reproducible LLM response generation

Anthropic models do not support a seed, but offline, Gemini and OpenAI models do.
Set the KHOJ_LLM_SEED env var to get reproducible responses when debugging and testing Khoj.
Debanjum committed 2024-10-30 02:20:38 -07:00
commit b3a63017b5, parent d44e68ba01
4 changed files with 17 additions and 2 deletions
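
The change follows the same pattern in each provider that supports it: read KHOJ_LLM_SEED from the environment, parse it as an integer, and pass it along to the chat completion call. A minimal sketch of that pattern as used in the hunks below:

    import os

    # None when KHOJ_LLM_SEED is unset, so generation stays non-deterministic by default
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None

Exporting, say, KHOJ_LLM_SEED=42 before starting Khoj should then make repeated identical prompts produce (near) identical responses from the offline, Gemini and OpenAI chat models, which is what makes it handy for debugging and tests.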


@@ -87,7 +87,7 @@ dependencies = [
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
     "docx2txt == 0.8",
-    "google-generativeai == 0.7.2"
+    "google-generativeai == 0.8.3"
 ]
 dynamic = ["version"]


@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from threading import Thread


@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Optional, Union
@@ -265,8 +266,14 @@ def send_message_to_model_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
+        messages_dict,
+        stop=stop,
+        stream=streaming,
+        temperature=temperature,
+        response_format={"type": response_type},
+        seed=seed,
     )

     if streaming:
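
For the offline path, llama-cpp-python's create_chat_completion accepts the seed directly, as the hunk above shows. A rough standalone sketch of the same call outside Khoj (the model path and prompt are placeholders, not from this commit; assumes llama-cpp-python and a local GGUF model):

    import os

    from llama_cpp import Llama

    llm = Llama(model_path="model.gguf")  # placeholder path
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
    response = llm.create_chat_completion(
        [{"role": "user", "content": "Say hello"}],
        temperature=0.2,
        seed=seed,  # a fixed seed pins the sampler's RNG, so identical inputs repeat
    )
    print(response["choices"][0]["message"]["content"])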


@@ -1,4 +1,5 @@
 import logging
+import os
 from threading import Thread
 from typing import Dict
@@ -60,6 +61,9 @@ def completion_with_backoff(
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,  # type: ignore
@@ -157,6 +161,9 @@ def llm_thread(
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,
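
On the OpenAI side the seed travels through model_kwargs into client.chat.completions.create, since the Chat Completions API accepts a seed parameter for best-effort determinism. A rough standalone equivalent (the model name is a placeholder, not from this commit; assumes the openai package and an OPENAI_API_KEY in the environment):

    import os

    from openai import OpenAI

    client = OpenAI()
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
    chat = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "Say hello"}],
        seed=seed,  # same seed + same request parameters -> more repeatable output
    )
    print(chat.choices[0].message.content)

Note that OpenAI documents seed as best effort rather than a guarantee, so tests that rely on it should still tolerate occasional drift.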