From 13b16a4364abf6114056a96f0a52c8e63736e738 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 3 Oct 2023 16:29:46 -0700
Subject: [PATCH] Use default Llama 2 supported by GPT4All

Remove the custom logic to download the Llama 2 model. It was only
needed because GPT4All didn't support Llama 2 when it was first added
to Khoj
---
 .../conversation/gpt4all/chat_model.py        |  4 +-
 .../conversation/gpt4all/model_metadata.py    |  3 -
 .../processor/conversation/gpt4all/utils.py   | 71 +------------------
 src/khoj/processor/conversation/utils.py      |  4 +-
 src/khoj/utils/config.py                      |  2 +-
 tests/test_gpt4all_chat_actors.py             |  2 +-
 6 files changed, 7 insertions(+), 79 deletions(-)
 delete mode 100644 src/khoj/processor/conversation/gpt4all/model_metadata.py

diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 9bc9ea52..d713831a 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 def extract_questions_offline(
     text: str,
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -123,7 +123,7 @@ def converse_offline(
     references,
     user_query,
     conversation_log={},
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
diff --git a/src/khoj/processor/conversation/gpt4all/model_metadata.py b/src/khoj/processor/conversation/gpt4all/model_metadata.py
deleted file mode 100644
index 065e3720..00000000
--- a/src/khoj/processor/conversation/gpt4all/model_metadata.py
+++ /dev/null
@@ -1,3 +0,0 @@
-model_name_to_url = {
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
-}
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index 4042fbe2..585df6a6 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -1,24 +1,8 @@
-import os
 import logging
-import requests
-import hashlib
-from tqdm import tqdm
-
-from khoj.processor.conversation.gpt4all import model_metadata
 
 
 logger = logging.getLogger(__name__)
 
-expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
-
-
-def get_md5_checksum(filename: str):
-    hash_md5 = hashlib.md5()
-    with open(filename, "rb") as f:
-        for chunk in iter(lambda: f.read(8192), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
-
 
 def download_model(model_name: str):
     try:
@@ -27,57 +11,4 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e
 
-    url = model_metadata.model_name_to_url.get(model_name)
-    model_path = os.path.expanduser(f"~/.cache/gpt4all/")
-    if not url:
-        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
-    if os.path.exists(filename):
-        # Check if the user is connected to the internet
-        try:
-            requests.get("https://www.google.com/", timeout=5)
-        except:
-            logger.debug("User is offline. Disabling allowed download flag")
-            return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
-    tmp_filename = filename + ".tmp"
-
-    try:
-        os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
-        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
-        with requests.get(url, stream=True) as r:
-            r.raise_for_status()
-            total_size = int(r.headers.get("content-length", 0))
-            with open(tmp_filename, "wb") as f, tqdm(
-                unit="B",  # unit string to be displayed.
-                unit_scale=True,  # let tqdm to determine the scale in kilo, mega..etc.
-                unit_divisor=1024,  # is used when unit_scale is true
-                total=total_size,  # the total iteration.
-                desc=model_name,  # prefix to be displayed on progress bar.
-            ) as progress_bar:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        # Verify the checksum
-        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
-            logger.error(
-                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
-            )
-            os.remove(tmp_filename)
-            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
-
-        # Move the tmp file to the actual file
-        os.rename(tmp_filename, filename)
-        logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
-        return GPT4All(model_name)
-    except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
-        # Remove the tmp file if it exists
-        if os.path.exists(tmp_filename):
-            os.remove(tmp_filename)
-        return None
+    return GPT4All(model_name=model_name)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 4a92c367..ece526c2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -17,10 +17,10 @@ logger = logging.getLogger(__name__)
 max_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}
 
 
 class ThreadedGenerator:
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index a6532346..f06d4c69 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -84,7 +84,7 @@ class SearchModels:
 
 @dataclass
 class GPT4AllProcessorConfig:
-    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
     loaded_model: Union[Any, None] = None
 
 
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index d7904ff8..32ee4020 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
 from khoj.processor.conversation.utils import message_to_log
 
 
-MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
 
 
 @pytest.fixture(scope="session")
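
The simplified helper leans entirely on GPT4All's own model registry: since the library now lists Llama 2 among its supported models, GPT4All(model_name=model_name) downloads the weights to its cache directory on first use and loads them from cache thereafter. That is what made the bespoke requests/tqdm/MD5 download pipeline and the model_metadata.py URL map removable. A minimal sketch of the behavior the new one-liner relies on (the prompt string and token limit below are illustrative, not from this patch):

```python
from gpt4all import GPT4All

# First run downloads llama-2-7b-chat.ggmlv3.q4_0.bin into GPT4All's
# cache directory (~/.cache/gpt4all/ on Linux); later runs load from cache.
model = GPT4All(model_name="llama-2-7b-chat.ggmlv3.q4_0.bin")

# Generate a completion with the locally loaded model
response = model.generate("Why is the sky blue?", max_tokens=128)
print(response)
```

One behavioral note for reviewers: the removed code passed allow_download=False when the model file was already cached but the machine appeared offline; with the plain constructor call, the bindings' default of allow_download=True applies, and GPT4All itself decides whether to hit the network.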