Support setting seed for reproducible LLM response generation

Anthropic models do not support setting a seed, but offline, Gemini, and OpenAI models do. Set the KHOJ_LLM_SEED environment variable to use these models for reproducible debugging and testing of Khoj.
commit b3a63017b5
parent d44e68ba01

4 changed files with 17 additions and 2 deletions
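For quick reference, this is the seed-reading pattern the diff below adds to each provider module; the launch command in the comment is an illustrative example, not part of this commit:

import os

# Read the optional reproducibility seed; when unset, seed stays None and
# generation remains non-deterministic. Illustrative launch: KHOJ_LLM_SEED=42 khoj
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None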
@@ -87,7 +87,7 @@ dependencies = [
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
     "docx2txt == 0.8",
-    "google-generativeai == 0.7.2"
+    "google-generativeai == 0.8.3"
 ]
 
 dynamic = ["version"]
@@ -1,4 +1,5 @@
 import logging
+import os
 import random
 from threading import Thread
 
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from datetime import datetime, timedelta
 from threading import Thread
 from typing import Any, Iterator, List, Optional, Union
@@ -265,8 +266,14 @@ def send_message_to_model_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
+    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
-        messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
+        messages_dict,
+        stop=stop,
+        stream=streaming,
+        temperature=temperature,
+        response_format={"type": response_type},
+        seed=seed,
     )
 
     if streaming:
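Since the offline path above forwards the seed straight to llama-cpp-python, reproducibility can be checked in isolation. A minimal sketch; the model path and prompt are placeholders:

import os
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # placeholder model path
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
# With the same seed, prompt, and sampling settings, the completion should repeat.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello"}],
    temperature=0.2,
    seed=seed,
)
print(response["choices"][0]["message"]["content"])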
@@ -1,4 +1,5 @@
 import logging
+import os
 from threading import Thread
 from typing import Dict
 
@@ -60,6 +61,9 @@ def completion_with_backoff(
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,  # type: ignore
@@ -157,6 +161,9 @@ def llm_thread(
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
+    if os.getenv("KHOJ_LLM_SEED"):
+        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
+
     chat = client.chat.completions.create(
         stream=stream,
         messages=formatted_messages,
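The OpenAI chat completions API accepts an optional seed parameter for best-effort determinism, which is why the model_kwargs pattern above works unchanged. A hedged standalone sketch; the client setup and model name are illustrative, not from this commit:

import os
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
model_kwargs = {}
if os.getenv("KHOJ_LLM_SEED"):
    model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

# Identical requests with the same seed should usually return the same text;
# the response's system_fingerprint can be compared to detect backend changes.
chat = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello"}],
    **model_kwargs,
)
print(chat.choices[0].message.content)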