Support setting seed for reproducible LLM response generation

Anthropic models do not support setting a seed, but offline, Gemini, and OpenAI
models do. Set the KHOJ_LLM_SEED environment variable to use these models for
deterministic debugging and testing of Khoj.
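
For example, the chat backends read the seed with the pattern below; any fixed
integer works, and an unset variable falls back to non-deterministic sampling:

    import os

    # Set before starting Khoj; read per request by the chat backends
    os.environ["KHOJ_LLM_SEED"] = "42"
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None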
Debanjum 2024-10-30 02:20:38 -07:00
parent d44e68ba01
commit b3a63017b5
4 changed files with 17 additions and 2 deletions

View file

@@ -87,7 +87,7 @@ dependencies = [
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
     "docx2txt == 0.8",
-    "google-generativeai == 0.7.2"
+    "google-generativeai == 0.8.3"
 ]
 dynamic = ["version"]
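
The google-generativeai bump to 0.8.3 is presumably what enables seeding Gemini
responses. A minimal sketch of the idea, assuming the 0.8.3 SDK accepts a seed
field on GenerationConfig and a GEMINI_API_KEY env var; neither appears in this
diff:

    import os

    import google.generativeai as genai

    genai.configure(api_key=os.environ["GEMINI_API_KEY"])  # assumed env var name
    model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model
    response = model.generate_content(
        "Pick a random word",
        # seed on GenerationConfig is an assumption about the 0.8.3 SDK
        generation_config=genai.GenerationConfig(temperature=0.7, seed=42),
    )
    print(response.text)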

View file

@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from threading import Thread

View file

@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Optional, Union
@@ -265,8 +266,14 @@ def send_message_to_model_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
+        messages_dict,
+        stop=stop,
+        stream=streaming,
+        temperature=temperature,
+        response_format={"type": response_type},
+        seed=seed,
     )
     if streaming:
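
Plumbing the seed into create_chat_completion makes offline responses
repeatable. A minimal sketch of the effect with llama-cpp-python directly; the
model path is a placeholder, not part of this commit:

    import os
    from llama_cpp import Llama

    seed = int(os.getenv("KHOJ_LLM_SEED", "42"))  # default only for this demo

    llm = Llama(model_path="model.gguf")  # placeholder local model file
    messages = [{"role": "user", "content": "Pick a random word"}]
    first = llm.create_chat_completion(messages, temperature=0.7, seed=seed)
    second = llm.create_chat_completion(messages, temperature=0.7, seed=seed)
    # With the same seed, inputs, and model, sampling should be identical
    assert first["choices"][0]["message"]["content"] == second["choices"][0]["message"]["content"]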

View file

@@ -1,4 +1,5 @@
 import logging
+import os
 from threading import Thread
 from typing import Dict
@@ -60,6 +61,9 @@ def completion_with_backoff(
     model_kwargs.pop("stop", None)
     model_kwargs.pop("response_format", None)
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,  # type: ignore
@@ -157,6 +161,9 @@ def llm_thread(
     model_kwargs.pop("stop", None)
     model_kwargs.pop("response_format", None)
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,
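
On the OpenAI side, seed is a documented best-effort determinism parameter, and
each response carries a system_fingerprint for spotting backend changes between
runs. A standalone sketch; the model name is a placeholder and OPENAI_API_KEY is
assumed to be set:

    import os
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "Pick a random word"}],
        seed=int(os.getenv("KHOJ_LLM_SEED", "42")),
    )
    # Same seed + same system_fingerprint should give (mostly) the same output
    print(response.system_fingerprint, response.choices[0].message.content)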