mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-22 04:08:08 +00:00
Use chat model name var name consistently and type openai chat utils
This commit is contained in:
parent
4915be0301
commit
b0abec39d5
5 changed files with 36 additions and 30 deletions
|
@ -105,7 +105,7 @@ def extract_questions_offline(
|
||||||
response = send_message_to_model_offline(
|
response = send_message_to_model_offline(
|
||||||
messages,
|
messages,
|
||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
model=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
|
@ -154,7 +154,7 @@ def converse_offline(
|
||||||
online_results={},
|
online_results={},
|
||||||
code_results={},
|
code_results={},
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||||
loaded_model: Union[Any, None] = None,
|
loaded_model: Union[Any, None] = None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
|
@ -174,8 +174,8 @@ def converse_offline(
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
|
||||||
tracer["chat_model"] = model
|
tracer["chat_model"] = model_name
|
||||||
current_date = datetime.now()
|
current_date = datetime.now()
|
||||||
|
|
||||||
if agent and agent.personality:
|
if agent and agent.personality:
|
||||||
|
@ -228,7 +228,7 @@ def converse_offline(
|
||||||
system_prompt,
|
system_prompt,
|
||||||
conversation_log,
|
conversation_log,
|
||||||
context_message=context_message,
|
context_message=context_message,
|
||||||
model_name=model,
|
model_name=model_name,
|
||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
@ -239,7 +239,7 @@ def converse_offline(
|
||||||
program_execution_context=additional_context,
|
program_execution_context=additional_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
|
||||||
|
|
||||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
|
||||||
|
@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
|
||||||
def send_message_to_model_offline(
|
def send_message_to_model_offline(
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
loaded_model=None,
|
loaded_model=None,
|
||||||
model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
stop=[],
|
stop=[],
|
||||||
|
@ -282,7 +282,7 @@ def send_message_to_model_offline(
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
|
||||||
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
|
||||||
response = offline_chat_model.create_chat_completion(
|
response = offline_chat_model.create_chat_completion(
|
||||||
|
@ -301,7 +301,7 @@ def send_message_to_model_offline(
|
||||||
|
|
||||||
# Save conversation trace for non-streaming responses
|
# Save conversation trace for non-streaming responses
|
||||||
# Streamed responses need to be saved by the calling function
|
# Streamed responses need to be saved by the calling function
|
||||||
tracer["chat_model"] = model
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, response_text, tracer)
|
commit_conversation_trace(messages, response_text, tracer)
|
||||||
|
|
|
@ -128,7 +128,7 @@ def send_message_to_model(
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
return completion_with_backoff(
|
return completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model_name=model,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
|
|
|
@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {}
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def completion_with_backoff(
|
def completion_with_backoff(
|
||||||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
|
messages,
|
||||||
|
model_name: str,
|
||||||
|
temperature=0,
|
||||||
|
openai_api_key=None,
|
||||||
|
api_base_url=None,
|
||||||
|
model_kwargs: dict = {},
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||||
|
@ -56,7 +62,7 @@ def completion_with_backoff(
|
||||||
|
|
||||||
# Update request parameters for compatability with o1 model series
|
# Update request parameters for compatability with o1 model series
|
||||||
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
|
||||||
if model.startswith("o1"):
|
if model_name.startswith("o1"):
|
||||||
temperature = 1
|
temperature = 1
|
||||||
model_kwargs.pop("stop", None)
|
model_kwargs.pop("stop", None)
|
||||||
model_kwargs.pop("response_format", None)
|
model_kwargs.pop("response_format", None)
|
||||||
|
@ -66,12 +72,12 @@ def completion_with_backoff(
|
||||||
|
|
||||||
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model, # type: ignore
|
model=model_name, # type: ignore
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stream_options={"include_usage": True} if stream else {},
|
stream_options={"include_usage": True} if stream else {},
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**(model_kwargs or dict()),
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
|
@ -91,10 +97,10 @@ def completion_with_backoff(
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
|
@ -139,11 +145,11 @@ def chat_completion_with_backoff(
|
||||||
def llm_thread(
|
def llm_thread(
|
||||||
g,
|
g,
|
||||||
messages,
|
messages,
|
||||||
model_name,
|
model_name: str,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
model_kwargs=None,
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
@ -177,7 +183,7 @@ def llm_thread(
|
||||||
stream_options={"include_usage": True} if stream else {},
|
stream_options={"include_usage": True} if stream else {},
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**(model_kwargs or dict()),
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
|
|
|
@ -985,7 +985,7 @@ async def send_message_to_model_wrapper(
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
model=chat_model_name,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
@ -1101,7 +1101,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
model=chat_model_name,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
@ -1251,7 +1251,7 @@ def generate_chat_response(
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
model=chat_model.name,
|
model_name=chat_model.name,
|
||||||
max_prompt_size=chat_model.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=chat_model.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
|
|
|
@ -204,10 +204,10 @@ def initialization(interactive: bool = True):
|
||||||
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
|
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
|
||||||
|
|
||||||
if interactive:
|
if interactive:
|
||||||
chat_model_names = input(
|
user_chat_models = input(
|
||||||
f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
|
f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
|
||||||
)
|
)
|
||||||
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
|
chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
|
||||||
chat_models = [model.strip() for model in chat_models]
|
chat_models = [model.strip() for model in chat_models]
|
||||||
else:
|
else:
|
||||||
chat_models = default_chat_models
|
chat_models = default_chat_models
|
||||||
|
@ -255,14 +255,14 @@ def initialization(interactive: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add new models
|
# Add new models
|
||||||
for model in available_models:
|
for model_name in available_models:
|
||||||
if not existing_models.filter(name=model).exists():
|
if not existing_models.filter(name=model_name).exists():
|
||||||
ChatModel.objects.create(
|
ChatModel.objects.create(
|
||||||
name=model,
|
name=model_name,
|
||||||
model_type=ChatModel.ModelType.OPENAI,
|
model_type=ChatModel.ModelType.OPENAI,
|
||||||
max_prompt_size=model_to_prompt_size.get(model),
|
max_prompt_size=model_to_prompt_size.get(model_name),
|
||||||
vision_enabled=model in default_openai_chat_models,
|
vision_enabled=model_name in default_openai_chat_models,
|
||||||
tokenizer=model_to_tokenizer.get(model),
|
tokenizer=model_to_tokenizer.get(model_name),
|
||||||
ai_model_api=config,
|
ai_model_api=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue