Use chat model name variable consistently and add type annotations to OpenAI chat utils

Debanjum 2024-12-12 14:05:53 -08:00
parent 4915be0301
commit b0abec39d5
5 changed files with 36 additions and 30 deletions
src/khoj/processor/conversation
src/khoj/routers
src/khoj/utils

View file

@@ -105,7 +105,7 @@ def extract_questions_offline(
response = send_message_to_model_offline(
messages,
loaded_model=offline_chat_model,
- model=model,
+ model_name=model,
max_prompt_size=max_prompt_size,
temperature=temperature,
response_type="json_object",
@@ -154,7 +154,7 @@ def converse_offline(
online_results={},
code_results={},
conversation_log={},
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -174,8 +174,8 @@ def converse_offline(
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
- tracer["chat_model"] = model
+ offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
+ tracer["chat_model"] = model_name
current_date = datetime.now()
if agent and agent.personality:
@@ -228,7 +228,7 @@ def converse_offline(
system_prompt,
conversation_log,
context_message=context_message,
- model_name=model,
+ model_name=model_name,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
@@ -239,7 +239,7 @@ def converse_offline(
program_execution_context=additional_context,
)
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
@@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
def send_message_to_model_offline(
messages: List[ChatMessage],
loaded_model=None,
model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
temperature: float = 0.2,
streaming=False,
stop=[],
@@ -282,7 +282,7 @@ def send_message_to_model_offline(
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
- offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
+ offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
response = offline_chat_model.create_chat_completion(
@@ -301,7 +301,7 @@ def send_message_to_model_offline(
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)
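
With these renames, offline chat callers pass the model identifier as model_name rather than model. A minimal call sketch, assuming khoj is installed; the import paths, message content, and return handling are illustrative assumptions, while the keyword names and the default GGUF model id come from the diff above:

    from langchain_core.messages import ChatMessage  # import path is an assumption

    # Module path assumed from the directory listing; the diff does not show file names.
    from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

    messages = [ChatMessage(role="user", content="Summarize my notes on transformers.")]

    # Non-streaming call (streaming defaults to False). Downloads the GGUF model on first use.
    response = send_message_to_model_offline(
        messages,
        model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        temperature=0.2,
        response_type="json_object",
    )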

View file

@@ -128,7 +128,7 @@ def send_message_to_model(
# Get Response from GPT
return completion_with_backoff(
messages=messages,
- model=model,
+ model_name=model,
openai_api_key=api_key,
temperature=temperature,
api_base_url=api_base_url,

View file

@@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
- messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
+ messages,
+ model_name: str,
+ temperature=0,
+ openai_api_key=None,
+ api_base_url=None,
+ model_kwargs: dict = {},
+ tracer: dict = {},
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -56,7 +62,7 @@ def completion_with_backoff(
# Update request parameters for compatability with o1 model series
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
- if model.startswith("o1"):
+ if model_name.startswith("o1"):
temperature = 1
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
@@ -66,12 +72,12 @@ def completion_with_backoff(
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages, # type: ignore
- model=model, # type: ignore
+ model=model_name, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
- **(model_kwargs or dict()),
+ **model_kwargs,
)
aggregated_response = ""
@@ -91,10 +97,10 @@ def completion_with_backoff(
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
@@ -139,11 +145,11 @@ def chat_completion_with_backoff(
def llm_thread(
g,
messages,
- model_name,
+ model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
- model_kwargs=None,
+ model_kwargs: dict = {},
tracer: dict = {},
):
try:
@@ -177,7 +183,7 @@ def llm_thread(
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
- **(model_kwargs or dict()),
+ **model_kwargs,
)
aggregated_response = ""
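
On the OpenAI side, completion_with_backoff now takes model_name (typed as str) and a model_kwargs dict that defaults to {}. A hedged sketch of a direct, non-streaming call; the module path, model id, and message content are illustrative assumptions, while the parameter names come from the signature above:

    import os

    from langchain_core.messages import ChatMessage  # import path is an assumption

    # Module path assumed from the directory listing; the diff does not show file names.
    from khoj.processor.conversation.openai.utils import completion_with_backoff

    messages = [ChatMessage(role="user", content="Extract search queries from this note.")]

    answer = completion_with_backoff(
        messages=messages,
        model_name="gpt-4o-mini",  # placeholder model id; o1-* names get stop/response_format stripped above
        temperature=0,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        model_kwargs={"response_format": {"type": "json_object"}},
    )
    print(answer)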

View file

@@ -985,7 +985,7 @@ async def send_message_to_model_wrapper(
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
- model=chat_model_name,
+ model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
@@ -1101,7 +1101,7 @@ def send_message_to_model_wrapper_sync(
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
- model=chat_model_name,
+ model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
@@ -1251,7 +1251,7 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_commands=conversation_commands,
- model=chat_model.name,
+ model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,

View file

@@ -204,10 +204,10 @@ def initialization(interactive: bool = True):
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
if interactive:
- chat_model_names = input(
+ user_chat_models = input(
f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
)
- chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
+ chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_chat_models
@@ -255,14 +255,14 @@
)
# Add new models
- for model in available_models:
- if not existing_models.filter(name=model).exists():
+ for model_name in available_models:
+ if not existing_models.filter(name=model_name).exists():
ChatModel.objects.create(
- name=model,
+ name=model_name,
model_type=ChatModel.ModelType.OPENAI,
- max_prompt_size=model_to_prompt_size.get(model),
- vision_enabled=model in default_openai_chat_models,
- tokenizer=model_to_tokenizer.get(model),
+ max_prompt_size=model_to_prompt_size.get(model_name),
+ vision_enabled=model_name in default_openai_chat_models,
+ tokenizer=model_to_tokenizer.get(model_name),
ai_model_api=config,
)
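
The initialization changes are renames for clarity: the raw interactive input becomes user_chat_models and the loop variable becomes model_name. A standalone sketch of the interactive parsing step, with made-up default model names (the defaults khoj ships with are not shown in this diff):

    # Illustrative defaults; not the actual default_chat_models used by khoj.
    default_chat_models = ["gpt-4o-mini", "gpt-4o"]

    user_chat_models = input(
        f"Enter the chat models you want to use (default: {','.join(default_chat_models)}): "
    )
    chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
    chat_models = [model.strip() for model in chat_models]
    print(chat_models)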