Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 09:25:06 +01:00)
Support setting seed for reproducible LLM response generation
Anthropic models do not support a seed, but offline, Gemini, and OpenAI models do. Set the KHOJ_LLM_SEED environment variable with one of those models to debug and test Khoj with reproducible responses.
This commit is contained in:
parent
d44e68ba01
commit
b3a63017b5
4 changed files with 17 additions and 2 deletions
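
The mechanism is a single environment variable read at request time. Below is a minimal sketch of the pattern the diffs in this commit apply inline at each call site; the helper name here is hypothetical and not part of the codebase:

```python
import os
from typing import Optional


def get_llm_seed() -> Optional[int]:
    # Return a fixed seed when KHOJ_LLM_SEED is set, else None so the
    # model keeps sampling non-deterministically, as before this commit.
    raw = os.getenv("KHOJ_LLM_SEED")
    return int(raw) if raw else None
```

Starting Khoj with the variable set (e.g. `KHOJ_LLM_SEED=42`) should then make offline, Gemini, and OpenAI responses repeatable across runs.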
@@ -87,7 +87,7 @@ dependencies = [
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
     "docx2txt == 0.8",
-    "google-generativeai == 0.7.2"
+    "google-generativeai == 0.8.3"
 ]
 dynamic = ["version"]
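The google-generativeai bump from 0.7.2 to 0.8.3 can be sanity-checked in an installed environment; this assumes the package exposes `__version__`, which recent releases do:

```python
import google.generativeai as genai

# Expect "0.8.3" after this dependency bump.
print(genai.__version__)
```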
@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from threading import Thread
 
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Optional, Union
@@ -265,8 +266,14 @@ def send_message_to_model_offline
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
+        messages_dict,
+        stop=stop,
+        stream=streaming,
+        temperature=temperature,
+        response_format={"type": response_type},
+        seed=seed,
     )
 
     if streaming:
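llama-cpp-python's `create_chat_completion` accepts a `seed` keyword, which the hunk above forwards; passing `seed=None` preserves the previous non-deterministic behaviour. A rough reproducibility check, assuming a recent llama-cpp-python and a chat-capable GGUF model at a placeholder path:

```python
from llama_cpp import Llama

# Placeholder model path; substitute any local chat-capable GGUF model.
model = Llama(model_path="model.gguf", verbose=False)
messages = [{"role": "user", "content": "Say hello in five words."}]

# With the same seed and temperature, both completions should match.
first = model.create_chat_completion(messages, temperature=0.7, seed=42)
second = model.create_chat_completion(messages, temperature=0.7, seed=42)
assert first["choices"][0]["message"]["content"] == second["choices"][0]["message"]["content"]
```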
@@ -1,4 +1,5 @@
 import logging
+import os
 from threading import Thread
 from typing import Dict
 
@@ -60,6 +61,9 @@ def completion_with_backoff
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,  # type: ignore
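OpenAI documents `seed` as best-effort determinism: the same seed combined with the same `system_fingerprint` should usually yield the same output. A standalone sketch of the call shape used above (assumes `OPENAI_API_KEY` is set; the model name is illustrative):

```python
from openai import OpenAI

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Pick a random fruit."}],
    seed=42,
)
# system_fingerprint identifies the backend configuration; compare it across
# runs when checking whether two seeded responses should match.
print(response.system_fingerprint)
print(response.choices[0].message.content)
```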
@@ -157,6 +161,9 @@ def llm_thread
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,
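The same `model_kwargs` path covers the streaming variant in `llm_thread`, so a seeded run should stream an identical token sequence whenever the backend honours the seed. A minimal streaming sketch under the same assumptions as above:

```python
import os

from openai import OpenAI

client = OpenAI()
# Fall back to a fixed demo seed when KHOJ_LLM_SEED is unset.
seed = int(os.getenv("KHOJ_LLM_SEED", "42"))

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Count to three."}],
    seed=seed,
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```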