Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-12-02 20:03:01 +01:00)
Use default Llama 2 supported by GPT4All
Remove the custom logic for downloading a custom Llama 2 model. This was only added because GPT4All did not yet support Llama 2 when Llama 2 was first added to Khoj.
Parent: 4a5ed7f06c
Commit: 13b16a4364
6 changed files with 7 additions and 79 deletions
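For context, the GPT4All Python bindings can fetch any model in their supported catalogue on first use, which is what makes the custom download path below removable. A minimal sketch of that behaviour (the model name comes from this commit; treating ~/.cache/gpt4all as the cache location is an assumption about GPT4All's defaults):

    from gpt4all import GPT4All

    # GPT4All downloads the weights into its own cache on first use if the
    # file is not already present, so no custom download helper is needed.
    model = GPT4All(model_name="llama-2-7b-chat.ggmlv3.q4_0.bin")
    print(model.generate("Say hello in one short sentence."))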
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 def extract_questions_offline(
     text: str,
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -123,7 +123,7 @@ def converse_offline(
     references,
     user_query,
     conversation_log={},
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_command=ConversationCommand.Default,
@@ -1,3 +0,0 @@
-model_name_to_url = {
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
-}
@@ -1,24 +1,8 @@
-import os
 import logging
-import requests
-import hashlib
-
-from tqdm import tqdm
 
-from khoj.processor.conversation.gpt4all import model_metadata
-
 logger = logging.getLogger(__name__)
 
-expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
-
-
-def get_md5_checksum(filename: str):
-    hash_md5 = hashlib.md5()
-    with open(filename, "rb") as f:
-        for chunk in iter(lambda: f.read(8192), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
 
 
 def download_model(model_name: str):
     try:
@@ -27,57 +11,4 @@ def download_model(model_name: str):
         logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
         raise e
 
-    url = model_metadata.model_name_to_url.get(model_name)
-    model_path = os.path.expanduser(f"~/.cache/gpt4all/")
-    if not url:
-        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
-    if os.path.exists(filename):
-        # Check if the user is connected to the internet
-        try:
-            requests.get("https://www.google.com/", timeout=5)
-        except:
-            logger.debug("User is offline. Disabling allowed download flag")
-            return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
-        return GPT4All(model_name=model_name, model_path=model_path)
-
-    # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
-    tmp_filename = filename + ".tmp"
-
-    try:
-        os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
-        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
-        with requests.get(url, stream=True) as r:
-            r.raise_for_status()
-            total_size = int(r.headers.get("content-length", 0))
-            with open(tmp_filename, "wb") as f, tqdm(
-                unit="B",  # unit string to be displayed.
-                unit_scale=True,  # let tqdm to determine the scale in kilo, mega..etc.
-                unit_divisor=1024,  # is used when unit_scale is true
-                total=total_size,  # the total iteration.
-                desc=model_name,  # prefix to be displayed on progress bar.
-            ) as progress_bar:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        # Verify the checksum
-        if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
-            logger.error(
-                f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
-            )
-            os.remove(tmp_filename)
-            raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
-
-        # Move the tmp file to the actual file
-        os.rename(tmp_filename, filename)
-        logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
-        return GPT4All(model_name)
-    except Exception as e:
-        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
-        # Remove the tmp file if it exists
-        if os.path.exists(tmp_filename):
-            os.remove(tmp_filename)
-        return None
+    return GPT4All(model_name=model_name)
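Assembled from the surviving lines of the two hunks above, the simplified helper now reduces to roughly the following. The body of the import guard (the try/except around the gpt4all import) is not shown in the diff context, so its exact form here is an assumption:

    import logging

    logger = logging.getLogger(__name__)


    def download_model(model_name: str):
        # Import lazily so Khoj can run without the optional gpt4all dependency.
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError as e:
            logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
            raise e

        # GPT4All now knows how to fetch its supported Llama 2 builds itself.
        return GPT4All(model_name=model_name)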
@@ -17,10 +17,10 @@ logger = logging.getLogger(__name__)
 max_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}
 
 
 class ThreadedGenerator:
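The tokenizer mapping above pairs the offline chat model with a Hugging Face tokenizer, presumably so prompt length can be measured against max_prompt_size. A hypothetical illustration of that pairing (count_tokens is not a Khoj function, and both dicts are restated here only to keep the snippet self-contained):

    from transformers import AutoTokenizer

    max_prompt_size = {"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548}
    tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}


    def count_tokens(text: str, model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin") -> int:
        # Tokenize with the model's paired tokenizer and report the token count,
        # which can then be compared against max_prompt_size[model].
        tok = AutoTokenizer.from_pretrained(tokenizer[model])
        return len(tok.encode(text))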
@@ -84,7 +84,7 @@ class SearchModels:
 
 @dataclass
 class GPT4AllProcessorConfig:
-    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
     loaded_model: Union[Any, None] = None
 
 
@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
 
 from khoj.processor.conversation.utils import message_to_log
 
-MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
 
 
 @pytest.fixture(scope="session")
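The session-scoped fixture that follows in the test file is truncated in this diff. A hypothetical sketch of how such a fixture could hand tests a loaded model via the simplified download_model (the fixture name and body are assumptions, not the actual test code):

    import pytest

    from khoj.processor.conversation.gpt4all.utils import download_model

    MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"


    # Hypothetical fixture: load the offline model once per test session.
    @pytest.fixture(scope="session")
    def loaded_model():
        return download_model(MODEL_NAME)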