mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Add support for Llama 3 in Khoj offline mode
- Improve extract question prompts to explicitly request JSON list - Use llama-3 chat format if HF repo_id mentions llama-3. The llama-cpp-python logic for detecting when to use llama-3 chat format isn't robust enough currently
This commit is contained in:
parent
8e77b3dc82
commit
a2e4e4bede
2 changed files with 7 additions and 2 deletions
|
@ -64,7 +64,7 @@ dependencies = [
|
||||||
"pymupdf >= 1.23.5",
|
"pymupdf >= 1.23.5",
|
||||||
"django == 4.2.10",
|
"django == 4.2.10",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"llama-cpp-python == 0.2.56",
|
"llama-cpp-python == 0.2.64",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
"httpx == 0.25.0",
|
"httpx == 0.25.0",
|
||||||
"pgvector == 0.2.4",
|
"pgvector == 0.2.4",
|
||||||
|
|
|
@ -2,6 +2,7 @@ import glob
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from huggingface_hub.constants import HF_HUB_CACHE
|
from huggingface_hub.constants import HF_HUB_CACHE
|
||||||
|
|
||||||
|
@ -14,12 +15,16 @@ logger = logging.getLogger(__name__)
|
||||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
|
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
|
||||||
# Initialize Model Parameters
|
# Initialize Model Parameters
|
||||||
# Use n_ctx=0 to get context size from the model
|
# Use n_ctx=0 to get context size from the model
|
||||||
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
||||||
|
|
||||||
# Decide whether to load model to GPU or CPU
|
# Decide whether to load model to GPU or CPU
|
||||||
device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
|
device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
|
||||||
kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
|
kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
|
||||||
|
|
||||||
|
# Add chat format if known
|
||||||
|
if "llama-3" in repo_id.lower():
|
||||||
|
kwargs["chat_format"] = "llama-3"
|
||||||
|
|
||||||
# Check if the model is already downloaded
|
# Check if the model is already downloaded
|
||||||
model_path = load_model_from_cache(repo_id, filename)
|
model_path = load_model_from_cache(repo_id, filename)
|
||||||
chat_model = None
|
chat_model = None
|
||||||
|
|
Loading…
Reference in a new issue