Save Image Search Model to Disk
parent 934ec233b0
commit 510faa1904
4 changed files with 15 additions and 7 deletions
@@ -36,6 +36,7 @@ search-type:
   image:
     encoder: "clip-ViT-B-32"
+    model_directory: "tests/data/.image_encoder"

 processor:
   conversation:
@@ -145,7 +145,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
         # Extract Entries, Generate Image Embeddings
-        model.image_search = image_search.setup(config.content_type.image, regenerate=regenerate, verbose=verbose)
+        model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)

     return model
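With this change, initialize_search hands image search two separate pieces of configuration: what to index and how to load the encoder. A hedged reading of the two FullConfig fields used at the call site (their inner fields come from the other hunks of this commit, not from this one):

# config.content_type.image -> ImageSearchConfig      (e.g. input_directory: which images to index)
# config.search_type.image  -> ImageSearchTypeConfig  (encoder name plus the new model_directory cache path)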
@@ -10,16 +10,22 @@ from tqdm import trange
 import torch

 # Internal Packages
-from src.utils.helpers import resolve_absolute_path
+from src.utils.helpers import resolve_absolute_path, load_model
 import src.utils.exiftool as exiftool
 from src.utils.config import ImageSearchModel
-from src.utils.rawconfig import ImageSearchConfig
+from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig


-def initialize_model():
+def initialize_model(search_config: ImageSearchTypeConfig):
     # Initialize Model
     torch.set_num_threads(4)
-    encoder = SentenceTransformer('sentence-transformers/clip-ViT-B-32') #Load the CLIP model
+
+    # Load the CLIP model
+    encoder = load_model(
+        model_dir = search_config.model_directory,
+        model_name = search_config.encoder,
+        model_type = SentenceTransformer)

     return encoder
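The load_model helper now imported from src.utils.helpers is not shown in this diff. A minimal sketch of a disk-caching loader along those lines, assuming it reuses a model previously saved under model_dir and otherwise downloads model_name and saves it there (the body below is an assumption, not the actual helper in src.utils.helpers):

# Hypothetical sketch of a disk-caching model loader; not the real src.utils.helpers.load_model
from pathlib import Path

def load_model(model_dir, model_name, model_type):
    "Load the model from model_dir if it was saved there earlier, else download it by name and cache it to disk."
    if model_dir and Path(model_dir).exists():
        return model_type(str(model_dir))   # reuse the copy saved on an earlier run
    model = model_type(model_name)          # first run: download the pretrained model
    if model_dir:
        model.save(str(model_dir))          # SentenceTransformer.save persists the model to a directory
    return model

Called as in the hunk above with model_type=SentenceTransformer, the first run would download clip-ViT-B-32 and later runs would load it straight from tests/data/.image_encoder.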
@@ -154,9 +160,9 @@ def collate_results(hits, image_names, image_directory, count=5):
             in hits[0:count]]


-def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
+def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
     # Initialize Model
-    encoder = initialize_model()
+    encoder = initialize_model(search_config)

     # Extract Entries
     image_directory = resolve_absolute_path(config.input_directory, strict=True)
@@ -49,6 +49,7 @@ class AsymmetricConfig(ConfigBase):

 class ImageSearchTypeConfig(ConfigBase):
     encoder: Optional[str]
+    model_directory: Optional[Path]

 class SearchTypeConfig(ConfigBase):
     asymmetric: Optional[AsymmetricConfig]
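Assuming ConfigBase is a pydantic BaseModel (its definition is not part of this diff), the string value from the test config hunk above is coerced into a pathlib.Path automatically. A minimal sketch:

# Minimal sketch, assuming ConfigBase is a pydantic BaseModel
from src.utils.rawconfig import ImageSearchTypeConfig

image_search_type = ImageSearchTypeConfig(
    encoder="clip-ViT-B-32",
    model_directory="tests/data/.image_encoder")

print(type(image_search_type.model_directory))  # pydantic coerces the YAML string to pathlib.Path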