diff --git a/pyproject.toml b/pyproject.toml
index a52fc9b6..e6773b88 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,8 +59,8 @@ dependencies = [
     "bs4 >= 0.0.1",
     "anyio == 3.7.1",
     "pymupdf >= 1.23.3",
-    "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
-    "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
+    "gpt4all >= 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
+    "gpt4all >= 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
 ]
 
 dynamic = ["version"]
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index ece526c2..96c4c1c8 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -19,8 +19,12 @@ max_prompt_size = {
     "gpt-4": 8192,
     "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
+    "default": 1600,
+}
+tokenizer = {
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+    "default": "hf-internal-testing/llama-tokenizer",
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}
 
 
 class ThreadedGenerator:
@@ -105,7 +109,7 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
 
     # Return message in chronological order
     return messages[::-1]
@@ -116,8 +120,10 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name)
 
     if "llama" in model_name:
         encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    else:
+    elif "gpt" in model_name:
         encoder = tiktoken.encoding_for_model(model_name)
+    else:
+        encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])
 
     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 267af330..3898d1b8 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -122,6 +122,7 @@ def generate_chat_response(
                 conversation_log=meta_log,
                 completion_func=partial_completion,
                 conversation_command=conversation_command,
+                model=state.processor_config.conversation.gpt4all_model.chat_model,
             )
 
     elif state.processor_config.conversation.openai_model:
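
For context, here is a minimal standalone sketch of the fallback behavior these hunks introduce: unknown model names now resolve to a `"default"` prompt size and tokenizer instead of raising a `KeyError`. The dicts are copied from the diff; the unknown model name in the usage line is purely hypothetical.

```python
import tiktoken
from transformers import LlamaTokenizerFast

max_prompt_size = {
    "gpt-3.5-turbo": 4096,
    "gpt-4": 8192,
    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
    "gpt-3.5-turbo-16k": 15000,
    "default": 1600,
}
tokenizer = {
    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}


def select_encoder(model_name: str):
    # Mirrors the new branching in truncate_messages(): registered llama
    # models use their HuggingFace tokenizer, OpenAI models use tiktoken,
    # and any other model falls back to the default llama tokenizer.
    if "llama" in model_name:
        return LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
    elif "gpt" in model_name:
        return tiktoken.encoding_for_model(model_name)
    return LlamaTokenizerFast.from_pretrained(tokenizer["default"])


# An unrecognized model (hypothetical name) now gets a usable prompt size
# instead of crashing the chat request:
size = max_prompt_size.get("mistral-7b.q4_0.bin", max_prompt_size["default"])  # -> 1600
```

The `.get(..., max_prompt_size["default"])` pattern keeps `generate_chatml_messages_with_context` total rather than partial over model names, which matters once the gpt4all chat model becomes user-configurable (the `model=` argument threaded through in `helpers.py`).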