Mirror of https://github.com/khoj-ai/khoj.git
Add checksums to verify the correct model is downloaded as expected (#405)
* Add checksums to verify the correct model is downloaded as expected
  - This should help debug issues related to corrupted model downloads
  - If the download fails, let the application continue
* If the model is not downloaded as expected, add some indicators in the settings UI
* Add exc_info to the error log if/when the download fails for the llama v2 model
* Simplify the checksum-checking logic; update the key name in the model state for the web client
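In short, the commit downloads the model to a temporary file, hashes it, compares the hash against a known value, and only then moves it into place. A minimal, self-contained sketch of that verify-then-rename pattern (function and file names here are illustrative, not the commit's exact code):

```python
import hashlib
import os


def verify_and_install(tmp_path: str, final_path: str, expected_md5: str) -> None:
    """Promote a downloaded file to its final name only if its MD5 matches."""
    md5 = hashlib.md5()
    with open(tmp_path, "rb") as f:
        # Hash in 8 KB chunks so multi-GB model files never load fully into memory.
        for chunk in iter(lambda: f.read(8192), b""):
            md5.update(chunk)
    if md5.hexdigest() != expected_md5:
        os.remove(tmp_path)  # discard the corrupted download
        raise ValueError(f"Checksum verification failed for {tmp_path}")
    os.rename(tmp_path, final_path)  # atomic when both paths share a filesystem
```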
This commit is contained in:

parent 6aa998e047
commit 0baed742e4

4 changed files with 41 additions and 8 deletions
@@ -191,8 +191,8 @@
     <h3 class="card-title">
         Chat
         {% if current_config.processor and current_config.processor.conversation.openai %}
-        {% if current_model_state.conversation == False %}
-        <img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
+        {% if current_model_state.conversation_openai == False %}
+        <img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The OpenAI configuration did not work as expected.">
         {% else %}
         <img id="configured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
         {% endif %}
@@ -225,7 +225,10 @@
     <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
     <h3 class="card-title">
         Offline Chat
-        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+        {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+        <img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
+        {% endif %}
     </h3>
 </div>
 <div class="card-description-row">
@@ -1,6 +1,8 @@
+import os
 import logging
 import requests
+import hashlib

 from gpt4all import GPT4All
 from tqdm import tqdm
@@ -8,6 +10,16 @@ from khoj.processor.conversation.gpt4all import model_metadata

 logger = logging.getLogger(__name__)

+expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
+
+
+def get_md5_checksum(filename: str):
+    hash_md5 = hashlib.md5()
+    with open(filename, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
+
+
 def download_model(model_name: str):
     url = model_metadata.model_name_to_url.get(model_name)
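The helper hashes the file in fixed-size chunks, so even a multi-gigabyte model is verified in constant memory. To add a new entry to `expected_checksum`, one would compute the reference hash from a known-good copy of the model; a quick illustrative snippet (the local path is hypothetical):

```python
# Compute the reference MD5 for a known-good model file (path is hypothetical).
print(get_md5_checksum("models/llama-2-7b-chat.ggmlv3.q4_K_S.bin"))
```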
@@ -33,18 +45,26 @@ def download_model(model_name: str):
                 unit_scale=True,  # let tqdm to determine the scale in kilo, mega..etc.
                 unit_divisor=1024,  # is used when unit_scale is true
                 total=total_size,  # the total iteration.
-                desc=filename.split("/")[-1],  # prefix to be displayed on progress bar.
+                desc=model_name,  # prefix to be displayed on progress bar.
             ) as progress_bar:
                 for chunk in r.iter_content(chunk_size=8192):
                     f.write(chunk)
                     progress_bar.update(len(chunk))

+        # Verify the checksum
+        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
+            logger.error(
+                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
+            )
+            os.remove(tmp_filename)
+            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
+
         # Move the tmp file to the actual file
         os.rename(tmp_filename, filename)
         logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
         return GPT4All(model_name)
     except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
+        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
         # Remove the tmp file if it exists
         if os.path.exists(tmp_filename):
             os.remove(tmp_filename)
@@ -47,7 +47,7 @@ if not state.demo:
             "image": False,
             "github": False,
             "notion": False,
-            "conversation": False,
+            "enable_offline_model": False,
         }

         if state.content_index:
@@ -65,7 +65,8 @@ if not state.demo:
         if state.processor_config:
             successfully_configured.update(
                 {
                     "conversation": state.processor_config.conversation is not None,
+                    "conversation_openai": state.processor_config.conversation.openai_model is not None,
+                    "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
                 }
             )
@@ -2,6 +2,8 @@
 from __future__ import annotations  # to avoid quoting type hints

 from enum import Enum
+import logging
+
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
@@ -10,6 +12,8 @@ from khoj.processor.conversation.gpt4all.utils import download_model
 # External Packages
 import torch

+logger = logging.getLogger(__name__)
+
 # Internal Packages
 if TYPE_CHECKING:
     from sentence_transformers import CrossEncoder
@@ -95,7 +99,12 @@ class ConversationProcessorConfigModel:
         self.meta_log: dict = {}

         if self.enable_offline_chat:
-            self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+            try:
+                self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+            except ValueError as e:
+                self.gpt4all_model.loaded_model = None
+                logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
+                self.enable_offline_chat = False
         else:
             self.gpt4all_model.loaded_model = None
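The net effect is graceful degradation: a corrupted or failed download surfaces as a ValueError from download_model, which the config model converts into a disabled offline-chat feature (and a question-mark indicator in the settings UI) rather than a startup crash. A generic sketch of that pattern (names are illustrative, not khoj's API):

```python
import logging

logger = logging.getLogger(__name__)


def load_optional_model(loader, feature_flags: dict, flag: str):
    """Load an optional model; disable its feature flag instead of crashing."""
    try:
        return loader()  # e.g. a download-and-verify function that raises ValueError
    except ValueError as e:
        logger.error(f"Disabling '{flag}' after model load failure: {e}", exc_info=True)
        feature_flags[flag] = False  # the rest of the application keeps running
        return None
```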