Add checksums to verify the correct model is downloaded as expected (#405)

* Add checksums to verify the correct model is downloaded as expected
- This should help debug issues related to corrupted model downloads
- If the download fails, let the application continue
* If the model is not downloaded as expected, add an indicator in the settings UI
* Add exc_info to the error log if/when the download fails for the llamav2 model
* Simplify checksum checking logic, update key name in model state for web client
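
In outline, the change chains three steps: verify the downloaded file's MD5 checksum, fall back gracefully on a mismatch, and surface the model state to the settings UI. A minimal sketch of that flow, using the import and key names shown in the diffs below (the surrounding glue is illustrative, not the actual call site):

from khoj.processor.conversation.gpt4all.utils import download_model

try:
    loaded_model = download_model("llama-2-7b-chat.ggmlv3.q4_K_S.bin")
except ValueError:
    # Checksum mismatch: let the application continue without offline chat.
    loaded_model = None

# Drives the confirm / question-mark icon in the settings UI.
model_state = {"conversation_gpt4all": loaded_model is not None}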
sabaimran 2023-08-03 06:26:52 +00:00 committed by GitHub
parent 6aa998e047
commit 0baed742e4
4 changed files with 41 additions and 8 deletions

View file

@@ -191,8 +191,8 @@
 <h3 class="card-title">
 Chat
 {% if current_config.processor and current_config.processor.conversation.openai %}
-{% if current_model_state.conversation == False %}
-<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
+{% if current_model_state.conversation_openai == False %}
+<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The OpenAI configuration did not work as expected.">
 {% else %}
 <img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
 {% endif %}
@@ -225,7 +225,10 @@
 <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
 <h3 class="card-title">
 Offline Chat
-<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
+{% endif %}
 </h3>
 </div>
 <div class="card-description-row">
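
The offline chat card now distinguishes three states. A small Python sketch of the equivalent decision (the helper function itself is hypothetical; only the icon names and conditions come from the template above):

def offline_chat_icon(offline_chat_enabled: bool, model_loaded: bool) -> str:
    # Confirm icon is "enabled" only when offline chat is configured
    # and the model actually loaded.
    if offline_chat_enabled and model_loaded:
        return "confirm-icon (enabled)"
    # Configured, but the model failed to download: flag it visibly.
    if offline_chat_enabled and not model_loaded:
        return "question-mark-icon (model was not downloaded as expected)"
    # Offline chat not configured at all.
    return "confirm-icon (disabled)"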

View file

@@ -1,6 +1,8 @@
+import os
 import logging
 import requests
+import hashlib

 from gpt4all import GPT4All
 from tqdm import tqdm
@@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata
 logger = logging.getLogger(__name__)
+
+expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
+
+def get_md5_checksum(filename: str):
+    hash_md5 = hashlib.md5()
+    with open(filename, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
+
 def download_model(model_name: str):
     url = model_metadata.model_name_to_url.get(model_name)
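
Reading in 8192-byte chunks keeps memory flat even for multi-gigabyte model files. A minimal usage sketch pairing the new helper with the expected_checksum table (the temp path is illustrative):

tmp_path = "/tmp/llama-2-7b-chat.ggmlv3.q4_K_S.bin.tmp"  # hypothetical location
if get_md5_checksum(tmp_path) == expected_checksum["llama-2-7b-chat.ggmlv3.q4_K_S.bin"]:
    print("Download verified; safe to move into place")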
@@ -33,18 +45,26 @@ def download_model(model_name: str):
             unit_scale=True,  # let tqdm determine the scale in kilo, mega, etc.
             unit_divisor=1024,  # is used when unit_scale is true
             total=total_size,  # the total iteration.
-            desc=filename.split("/")[-1],  # prefix to be displayed on progress bar.
+            desc=model_name,  # prefix to be displayed on progress bar.
         ) as progress_bar:
             for chunk in r.iter_content(chunk_size=8192):
                 f.write(chunk)
                 progress_bar.update(len(chunk))

+        # Verify the checksum
+        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
+            logger.error(
+                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
+            )
+            os.remove(tmp_filename)
+            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
+
         # Move the tmp file to the actual file
         os.rename(tmp_filename, filename)
         logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
         return GPT4All(model_name)
     except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
+        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
         # Remove the tmp file if it exists
         if os.path.exists(tmp_filename):
             os.remove(tmp_filename)
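
The hunk above relies on a download-to-temp-then-promote pattern: stream into tmp_filename, verify, and only then os.rename() to the final path, so a half-written or corrupt file never sits where the loader looks. A condensed, self-contained sketch of the same pattern; unlike the commit, this variant hashes while streaming instead of re-reading the temp file, and the names are placeholders:

import hashlib
import os

import requests

def fetch_verified(url: str, dest: str, expected_md5: str) -> None:
    tmp = dest + ".part"  # hypothetical temp suffix
    digest = hashlib.md5()
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(tmp, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
                digest.update(chunk)  # hash as we write: no second read pass
    if digest.hexdigest() != expected_md5:
        os.remove(tmp)  # never leave a corrupt file behind
        raise ValueError(f"Checksum verification failed for {url}")
    os.rename(tmp, dest)  # atomic promotion on the same filesystem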

View file

@@ -47,7 +47,7 @@ if not state.demo:
         "image": False,
         "github": False,
         "notion": False,
-        "conversation": False,
+        "enable_offline_model": False,
     }

     if state.content_index:
@@ -65,7 +65,8 @@ if not state.demo:
     if state.processor_config:
         successfully_configured.update(
             {
-                "conversation": state.processor_config.conversation is not None,
+                "conversation_openai": state.processor_config.conversation.openai_model is not None,
+                "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
             }
         )
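
One subtlety in the hunk above: the new expressions dereference state.processor_config.conversation directly, so they assume a conversation config exists whenever processor_config does. If that assumption can fail (not shown in the diff), a defensive variant would look like:

conversation = state.processor_config.conversation if state.processor_config else None
successfully_configured.update(
    {
        # A missing conversation config reads the same as an unconfigured model.
        "conversation_openai": bool(conversation and conversation.openai_model),
        "conversation_gpt4all": bool(conversation and conversation.gpt4all_model.loaded_model),
    }
)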

View file

@@ -2,6 +2,8 @@
 from __future__ import annotations  # to avoid quoting type hints
 from enum import Enum
+import logging
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
@@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model

 # External Packages
 import torch

+logger = logging.getLogger(__name__)
+
 # Internal Packages
 if TYPE_CHECKING:
     from sentence_transformers import CrossEncoder
@@ -95,7 +99,12 @@ class ConversationProcessorConfigModel:
         self.meta_log: dict = {}

         if self.enable_offline_chat:
-            self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+            try:
+                self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+            except ValueError as e:
+                self.gpt4all_model.loaded_model = None
+                logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
+                self.enable_offline_chat = False
         else:
             self.gpt4all_model.loaded_model = None
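
The try/except above is a graceful-degradation pattern: a checksum failure disables offline chat instead of crashing startup, matching the "let the application continue" goal from the commit message. A standalone sketch of the same idea (the function and its names are illustrative):

import logging

logger = logging.getLogger(__name__)

def load_optional_model(loader, model_name: str):
    # Return the loaded model, or None so the caller can disable the feature.
    try:
        return loader(model_name)
    except ValueError as e:
        # exc_info=True attaches the traceback, which this commit adds to make
        # corrupted-download failures easier to debug.
        logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
        return None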