From b0abec39d509c8e0ffc9add9c9facb4d785c178c Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Thu, 12 Dec 2024 14:05:53 -0800
Subject: [PATCH] Use chat model name var name consistently and type openai chat utils

---
 .../conversation/offline/chat_model.py        | 18 +++++++-------
 src/khoj/processor/conversation/openai/gpt.py |  2 +-
 .../processor/conversation/openai/utils.py    | 24 ++++++++++++-------
 src/khoj/routers/helpers.py                   |  6 ++---
 src/khoj/utils/initialization.py              | 16 ++++++-------
 5 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 7e07dc20..5ce45bac 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -105,7 +105,7 @@ def extract_questions_offline(
     response = send_message_to_model_offline(
         messages,
         loaded_model=offline_chat_model,
-        model=model,
+        model_name=model,
         max_prompt_size=max_prompt_size,
         temperature=temperature,
         response_type="json_object",
@@ -154,7 +154,7 @@ def converse_offline(
     online_results={},
     code_results={},
     conversation_log={},
-    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+    model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -174,8 +174,8 @@ def converse_offline(
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    tracer["chat_model"] = model
+    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
+    tracer["chat_model"] = model_name
     current_date = datetime.now()
 
     if agent and agent.personality:
@@ -228,7 +228,7 @@ def converse_offline(
         system_prompt,
         conversation_log,
         context_message=context_message,
-        model_name=model,
+        model_name=model_name,
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
@@ -239,7 +239,7 @@ def converse_offline(
         program_execution_context=additional_context,
     )
 
-    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
+    logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
     t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
@@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
 def send_message_to_model_offline(
     messages: List[ChatMessage],
     loaded_model=None,
-    model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+    model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
     temperature: float = 0.2,
     streaming=False,
     stop=[],
@@ -282,7 +282,7 @@ def send_message_to_model_offline(
     tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
+    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
@@ -301,7 +301,7 @@ def send_message_to_model_offline(
 
     # Save conversation trace for non-streaming responses
     # Streamed responses need to be saved by the calling function
-    tracer["chat_model"] = model
+    tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
     if is_promptrace_enabled():
         commit_conversation_trace(messages, response_text, tracer)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index dbe672a4..389f52ab 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -128,7 +128,7 @@ def send_message_to_model(
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
-        model=model,
+        model_name=model,
         openai_api_key=api_key,
         temperature=temperature,
         api_base_url=api_base_url,
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 160af77c..84b0081b 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
+    messages,
+    model_name: str,
+    temperature=0,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs: dict = {},
+    tracer: dict = {},
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -56,7 +62,7 @@ def completion_with_backoff(
 
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model.startswith("o1"):
+    if model_name.startswith("o1"):
         temperature = 1
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
@@ -66,12 +72,12 @@ def completion_with_backoff(
     stream = True
     chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
-        model=model,  # type: ignore
+        model=model_name,  # type: ignore
         stream=stream,
         stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
-        **(model_kwargs or dict()),
+        **model_kwargs,
     )
 
     aggregated_response = ""
@@ -91,10 +97,10 @@ def completion_with_backoff(
             # Calculate cost of chat
            input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
            output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
-            tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
+            tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
 
     # Save conversation trace
-    tracer["chat_model"] = model
+    tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
     if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
@@ -139,11 +145,11 @@ def chat_completion_with_backoff(
 def llm_thread(
     g,
     messages,
-    model_name,
+    model_name: str,
     temperature,
     openai_api_key=None,
     api_base_url=None,
-    model_kwargs=None,
+    model_kwargs: dict = {},
     tracer: dict = {},
 ):
     try:
@@ -177,7 +183,7 @@ def llm_thread(
         stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
-        **(model_kwargs or dict()),
+        **model_kwargs,
     )
 
     aggregated_response = ""
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 3869f478..ecd1f1e4 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -985,7 +985,7 @@ async def send_message_to_model_wrapper(
         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
-            model=chat_model_name,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
@@ -1101,7 +1101,7 @@ def send_message_to_model_wrapper_sync(
         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
-            model=chat_model_name,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
@@ -1251,7 +1251,7 @@ def generate_chat_response(
                 conversation_log=meta_log,
                 completion_func=partial_completion,
                 conversation_commands=conversation_commands,
-                model=chat_model.name,
+                model_name=chat_model.name,
                 max_prompt_size=chat_model.max_prompt_size,
                 tokenizer_name=chat_model.tokenizer,
                 location_data=location_data,
diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index 108c15a9..a4864dcc 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -204,10 +204,10 @@ def initialization(interactive: bool = True):
         ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
 
         if interactive:
-            chat_model_names = input(
+            user_chat_models = input(
                 f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
             )
-            chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
+            chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
             chat_models = [model.strip() for model in chat_models]
         else:
             chat_models = default_chat_models
@@ -255,14 +255,14 @@ def initialization(interactive: bool = True):
         )
 
         # Add new models
-        for model in available_models:
-            if not existing_models.filter(name=model).exists():
+        for model_name in available_models:
+            if not existing_models.filter(name=model_name).exists():
                 ChatModel.objects.create(
-                    name=model,
+                    name=model_name,
                     model_type=ChatModel.ModelType.OPENAI,
-                    max_prompt_size=model_to_prompt_size.get(model),
-                    vision_enabled=model in default_openai_chat_models,
-                    tokenizer=model_to_tokenizer.get(model),
+                    max_prompt_size=model_to_prompt_size.get(model_name),
+                    vision_enabled=model_name in default_openai_chat_models,
+                    tokenizer=model_to_tokenizer.get(model_name),
                     ai_model_api=config,
                 )
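
Below is a minimal usage sketch, not part of the patch, showing how a caller passes the renamed model_name keyword to the retyped completion_with_backoff helper. The ChatMessage import path, the model name, and the key value are assumptions for illustration; only the keyword argument names come from the signature changed above.

# Illustrative sketch: exercises the keyword arguments renamed and typed in this patch.
# Assumes khoj is importable and that ChatMessage is the langchain message type whose
# .role and .content attributes the helper reads when formatting the request payload.
from langchain.schema import ChatMessage

from khoj.processor.conversation.openai.utils import completion_with_backoff

messages = [
    ChatMessage(role="system", content="You are Khoj, a helpful assistant."),
    ChatMessage(role="user", content="Summarize my notes on retrieval-augmented generation."),
]

response: str = completion_with_backoff(
    messages=messages,
    model_name="gpt-4o-mini",         # keyword was `model` before this patch
    temperature=0.2,
    openai_api_key="sk-placeholder",  # placeholder; supply a real key or configured credential
    api_base_url=None,                # None targets the default OpenAI endpoint
    model_kwargs={},                  # now typed as dict and defaulting to {}
)
print(response)

Offline callers follow the same pattern, passing model_name=chat_model_name to send_message_to_model_offline as the helpers.py hunks above show.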