Pass max context length to fix using updated GPT4All.list_gpu method

It's signature was updated in GPT4All 2.1.0 pypi release.

Resolves #610
This commit is contained in:
Debanjum Singh Solanky 2024-01-16 12:23:45 +05:30
parent 50575b749b
commit d74f8e03d3
2 changed files with 7 additions and 5 deletions

View file

@ -62,8 +62,8 @@ dependencies = [
"pymupdf >= 1.23.5",
"django == 4.2.7",
"authlib == 1.2.1",
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"itsdangerous == 2.1.2",
"httpx == 0.25.0",
"pgvector == 0.2.4",

View file

@ -21,9 +21,11 @@ def download_model(model_name: str):
# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
# 3. GPU has enough free memory to load the chat model
# 3. GPU has enough free memory to load the chat model with max context length of 4096
device = (
"gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu"
"gpu"
if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
else "cpu"
)
except ValueError:
device = "cpu"
@ -35,7 +37,7 @@ def download_model(model_name: str):
raise e
# Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
logger.debug(f"Loaded chat model to {device.upper()}.")
return chat_model