Use default Llama 2 supported by GPT4All

Remove the custom logic for downloading a custom Llama 2 model.
This was added because GPT4All did not support Llama 2 when it was first integrated into Khoj.
Debanjum Singh Solanky 2023-10-03 16:29:46 -07:00
parent 4a5ed7f06c
commit 13b16a4364
6 changed files with 7 additions and 79 deletions
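
In effect, the commit leans on the gpt4all Python package's built-in model downloader instead of Khoj's bespoke one. A minimal sketch of the behavior being relied on, assuming the gpt4all v1.x API (the constructor downloads models it recognizes on first use; allow_download defaults to True):

    from gpt4all import GPT4All

    # First instantiation fetches the model into gpt4all's own cache if it
    # is not already on disk; subsequent instantiations load it locally.
    model = GPT4All(model_name="llama-2-7b-chat.ggmlv3.q4_0.bin")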


@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 def extract_questions_offline(
     text: str,
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -123,7 +123,7 @@ def converse_offline(
     references,
     user_query,
     conversation_log={},
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
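
Callers that omit the model argument pick up the new q4_0 default automatically. A hypothetical call pattern, with the query strings invented for illustration:

    # Sketch only: argument names follow the signatures in the hunks above.
    questions = extract_questions_offline(
        "What did I work on last week?",
        loaded_model=loaded_model,  # reuse an already-loaded GPT4All instance
    )
    response = converse_offline(
        references=[],
        user_query="Summarize my week",
        loaded_model=loaded_model,
    )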


@@ -1,3 +0,0 @@
-model_name_to_url = {
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
-}


@@ -1,24 +1,8 @@
-import os
 import logging
-import requests
-import hashlib
-
-from tqdm import tqdm
-
-from khoj.processor.conversation.gpt4all import model_metadata

 logger = logging.getLogger(__name__)

-expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
-
-
-def get_md5_checksum(filename: str):
-    hash_md5 = hashlib.md5()
-    with open(filename, "rb") as f:
-        for chunk in iter(lambda: f.read(8192), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
-

 def download_model(model_name: str):
     try:
@@ -27,57 +11,4 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e

-    url = model_metadata.model_name_to_url.get(model_name)
-    model_path = os.path.expanduser(f"~/.cache/gpt4all/")
-    if not url:
-        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
-    if os.path.exists(filename):
-        # Check if the user is connected to the internet
-        try:
-            requests.get("https://www.google.com/", timeout=5)
-        except:
-            logger.debug("User is offline. Disabling allowed download flag")
-            return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
-    tmp_filename = filename + ".tmp"
-
-    try:
-        os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
-        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
-        with requests.get(url, stream=True) as r:
-            r.raise_for_status()
-            total_size = int(r.headers.get("content-length", 0))
-            with open(tmp_filename, "wb") as f, tqdm(
-                unit="B",  # unit string to be displayed.
-                unit_scale=True,  # let tqdm to determine the scale in kilo, mega..etc.
-                unit_divisor=1024,  # is used when unit_scale is true
-                total=total_size,  # the total iteration.
-                desc=model_name,  # prefix to be displayed on progress bar.
-            ) as progress_bar:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        # Verify the checksum
-        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
-            logger.error(
-                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
-            )
-            os.remove(tmp_filename)
-            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
-
-        # Move the tmp file to the actual file
-        os.rename(tmp_filename, filename)
-        logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
-        return GPT4All(model_name)
-    except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
-        # Remove the tmp file if it exists
-        if os.path.exists(tmp_filename):
-            os.remove(tmp_filename)
-        return None
+    return GPT4All(model_name=model_name)
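
With the bespoke downloader, md5 checksum verification, and tqdm progress bar gone, download_model reduces to handing the model name to GPT4All. A short usage sketch; the prompt and max_tokens value are illustrative, and generate() is the gpt4all package's standard completion call:

    from khoj.processor.conversation.gpt4all.utils import download_model

    # Downloads llama-2-7b-chat.ggmlv3.q4_0.bin on first use, then loads from disk.
    model = download_model("llama-2-7b-chat.ggmlv3.q4_0.bin")
    print(model.generate("Name the planets in the solar system.", max_tokens=64))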


@@ -17,10 +17,10 @@ logger = logging.getLogger(__name__)
 max_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}

 class ThreadedGenerator:
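
The max_prompt_size map caps context length per model, and the tokenizer map names a Hugging Face tokenizer to use for counting Llama tokens. A hypothetical sketch of how such maps can be applied; the truncate_prompt helper is an illustration, not Khoj's actual truncation logic:

    from transformers import AutoTokenizer

    def truncate_prompt(prompt: str, model: str) -> str:
        # Keep only the most recent tokens that fit the model's context window.
        enc = AutoTokenizer.from_pretrained(tokenizer[model])
        token_ids = enc.encode(prompt)
        return enc.decode(token_ids[-max_prompt_size[model]:])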


@@ -84,7 +84,7 @@ class SearchModels:

 @dataclass
 class GPT4AllProcessorConfig:
-    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
     loaded_model: Union[Any, None] = None


@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
 from khoj.processor.conversation.utils import message_to_log

-MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"

 @pytest.fixture(scope="session")
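
The hunk ends at a session-scoped fixture, which presumably loads the model once per test run. A hedged sketch of such a fixture, assuming it wraps download_model; the fixture name and body are guesses, since they fall outside the diff context:

    import pytest
    from khoj.processor.conversation.gpt4all.utils import download_model

    @pytest.fixture(scope="session")
    def loaded_model():
        # Hypothetical: load (and, if needed, download) the model once per session.
        return download_model(MODEL_NAME)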