From d74f8e03d3efbd9843149bcbb81aa65737821f3b Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 16 Jan 2024 12:23:45 +0530
Subject: [PATCH] Pass max context length to fix using the updated
 GPT4All.list_gpu method

Its signature was updated in the GPT4All 2.1.0 PyPI release.

Resolves #610
---
 pyproject.toml                                    | 4 ++--
 src/khoj/processor/conversation/offline/utils.py | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 693415d4..fbfa7dac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,8 +62,8 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 4.2.7",
     "authlib == 1.2.1",
-    "gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
-    "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
+    "gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
+    "gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 3a1862f7..9a2223c6 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -21,9 +21,11 @@ def download_model(model_name: str):
         # Try load chat model to GPU if:
         # 1. Loading chat model to GPU isn't disabled via CLI and
         # 2. Machine has GPU
-        # 3. GPU has enough free memory to load the chat model
+        # 3. GPU has enough free memory to load the chat model with max context length of 4096
         device = (
-            "gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu"
+            "gpu"
+            if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
+            else "cpu"
         )
     except ValueError:
         device = "cpu"
@@ -35,7 +37,7 @@ def download_model(model_name: str):
         raise e
 
     # Now load the downloaded chat model onto appropriate device
-    chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
+    chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
     logger.debug(f"Loaded chat model to {device.upper()}.")
     return chat_model
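
Note (illustration, not part of the diff): the sketch below shows the API change
this patch adapts to, mirroring the device-selection logic in download_model
above. It assumes gpt4all >= 2.1.0; the model path and name are hypothetical
placeholders, and Khoj's state.chat_on_gpu CLI gate is omitted for brevity.

import gpt4all

MODEL_PATH = "/path/to/model.gguf"  # hypothetical local model file
MODEL_NAME = "model.gguf"           # hypothetical model name
MAX_CONTEXT = 4096

try:
    # In gpt4all >= 2.1.0, LLModel.list_gpu also takes the required context
    # length, so it only reports GPUs with enough free memory to load the
    # model at that context size.
    usable_gpus = gpt4all.pyllmodel.LLModel().list_gpu(MODEL_PATH, MAX_CONTEXT)
    device = "gpu" if usable_gpus else "cpu"
except ValueError:
    # Mirrors the patch's fallback: a ValueError from the GPU check means
    # no usable GPU, so load on CPU instead.
    device = "cpu"

# n_ctx must match the context length used for the GPU memory check above
chat_model = gpt4all.GPT4All(model_name=MODEL_NAME, n_ctx=MAX_CONTEXT, device=device, allow_download=False)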