Reuse Search Models across Content Types to Reduce Memory Consumption

- Memory consumption now only scales with search models used, not with
  content types as well. Previously each content type had its own
  copy of the search ML models. That'd result in 300+ MB per enabled
  content type.

- Split model state into 2 separate state objects, `search_models` and
  `content_index`.
  This allows loading text_search and image_search models first and then
  reusing them across all content_types in content_index.

- This should cut down memory utilization quite a bit for most users.
  I see a ~50% drop in memory utilization.

  This will, of course, vary for each user based on the amount of
  content indexed vs number of plugins enabled

- This does not solve the RAM utilization scaling with size of the index,
  as the whole content index is still kept in RAM while Khoj is running.

Should help with #195, #301 and #303
This commit is contained in:
Debanjum Singh Solanky 2023-07-14 01:07:44 -07:00
parent c2249eadb2
commit 86e2bec9a0
8 changed files with 217 additions and 142 deletions

View file

@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from khoj.utils.config import (
ContentIndex,
SearchType,
SearchModels,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@ -49,12 +55,20 @@ def configure_server(args, required=False):
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
# Initialize the search type and model from Config
# Initialize Search Models from Config
state.search_index_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.model = configure_search(state.model, state.config, args.regenerate)
state.search_models = configure_search(state.search_models, state.config.search_type)
state.search_index_lock.release()
# Initialize Content from Config
if state.search_models:
state.search_index_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
)
state.search_index_lock.release()
def configure_routes(app):
# Import APIs here to setup search types before while configuring server
@ -73,7 +87,9 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes)
def update_search_index():
state.search_index_lock.acquire()
state.model = configure_search(state.model, state.config, regenerate=False)
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
)
state.search_index_lock.release()
logger.info("📬 Search index updated via Scheduler")
@ -90,94 +106,116 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
if config is None or config.content_type is None or config.search_type is None:
logger.warning("🚨 No Content or Search type is configured.")
return
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
# Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search type is configured.")
return None
if search_models is None:
search_models = SearchModels()
if model is None:
model = SearchModels()
# Initialize Search Models
if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
def configure_content(
content_index: Optional[ContentIndex],
content_config: Optional[ContentConfig],
search_models: SearchModels,
regenerate: bool,
t: Optional[state.SearchType] = None,
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
return None
if content_index is None:
content_index = ContentIndex()
try:
# Initialize Org Notes Search
if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
model.org_search = text_search.setup(
content_index.org = text_search.setup(
OrgToJsonl,
config.content_type.org,
search_config=config.search_type.asymmetric,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Markdown Search
if (
(t == state.SearchType.Markdown or t == None)
and config.content_type.markdown
and config.search_type.asymmetric
):
if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(
content_index.markdown = text_search.setup(
MarkdownToJsonl,
config.content_type.markdown,
search_config=config.search_type.asymmetric,
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize PDF Search
if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
model.pdf_search = text_search.setup(
content_index.pdf = text_search.setup(
PdfToJsonl,
config.content_type.pdf,
search_config=config.search_type.asymmetric,
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Image Search
if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(
config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
)
if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
model.github_search = text_search.setup(
content_index.github = text_search.setup(
GithubToJsonl,
config.content_type.github,
search_config=config.search_type.asymmetric,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize External Plugin Search
if (t == None or t in state.SearchType) and config.content_type.plugins:
if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
model.plugin_search = {}
for plugin_type, plugin_config in config.content_type.plugins.items():
model.plugin_search[plugin_type] = text_search.setup(
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
plugin_config,
search_config=config.search_type.asymmetric,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
# Initialize Notion Search
if (t == None or t in state.SearchType) and config.content_type.notion:
if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
logger.info("🔌 Setting up search for notion")
model.notion_search = text_search.setup(
content_index.notion = text_search.setup(
NotionToJsonl,
config.content_type.notion,
search_config=config.search_type.asymmetric,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
)
@ -189,7 +227,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Invalidate Query Cache
state.query_cache = LRU()
return model
return content_index
def configure_processor(processor_config: ProcessorConfig):

View file

@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.configure import configure_content, configure_processor, configure_search
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -102,17 +102,17 @@ if not state.demo:
state.config.content_type[content_type] = None
if content_type == "github":
state.model.github_search = None
state.content_index.github = None
elif content_type == "notion":
state.model.notion_search = None
state.content_index.notion = None
elif content_type == "plugins":
state.model.plugin_search = None
state.content_index.plugins = None
elif content_type == "pdf":
state.model.pdf_search = None
state.content_index.pdf = None
elif content_type == "markdown":
state.model.markdown_search = None
state.content_index.markdown = None
elif content_type == "org":
state.model.org_search = None
state.content_index.org = None
try:
save_config_to_file_updated_state()
@ -182,7 +182,7 @@ def get_config_types():
for search_type in SearchType
if (
search_type.value in configured_content_types
and getattr(state.model, f"{search_type.value}_search") is not None
and getattr(state.content_index, search_type.value) is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
@ -210,7 +210,7 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.model or not any(state.model.__dict__.values()):
if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results
@ -234,7 +234,7 @@ async def search(
encoded_asymmetric_query = None
if t == SearchType.All or t != SearchType.Image:
text_search_models: List[TextSearchModel] = [
model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
]
if text_search_models:
with timer("Encoding query took", logger=logger):
@ -247,13 +247,14 @@ async def search(
)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.org_search,
state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -261,13 +262,18 @@ async def search(
)
]
if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.markdown_search,
state.search_models.text_search,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -275,13 +281,18 @@ async def search(
)
]
if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.github_search,
state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -289,13 +300,14 @@ async def search(
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.pdf_search,
state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -303,26 +315,38 @@ async def search(
)
]
if (t == SearchType.Image) and state.model.image_search:
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
image_search.query,
user_query,
results_count,
state.model.image_search,
state.search_models.image_search,
state.content_index.image,
score_threshold=score_threshold,
)
]
if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [
executor.submit(
text_search.query,
user_query,
# Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
plugin_search,
plugin_content,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -330,13 +354,18 @@ async def search(
)
]
if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages
search_futures += [
executor.submit(
text_search.query,
user_query,
state.model.notion_search,
state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
@ -347,13 +376,13 @@ async def search(
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
if t == SearchType.Image:
if t == SearchType.Image and state.content_index.image:
hits = await search_future.result()
output_directory = constants.web_directory / "images"
# Collate results
results += image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
image_names=state.content_index.image.image_names,
output_directory=output_directory,
image_files_url="/static/images",
count=results_count,
@ -404,7 +433,12 @@ def update(
try:
state.search_index_lock.acquire()
try:
state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))

View file

@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
from PIL import Image
from tqdm import trange
import torch
from khoj.utils import state
# Internal Packages
from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from khoj.utils.config import ImageSearchModel
from khoj.utils.config import ImageContent, ImageSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
model_type=search_config.encoder_type or SentenceTransformer,
)
return encoder
return ImageSearchModel(encoder)
def extract_entries(image_directories):
@ -143,7 +145,9 @@ def extract_metadata(image_name):
return image_processed_metadata
async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa
# Now we encode the query (which can either be an image or a text string)
with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
if content.image_metadata_embeddings:
with timer("Metadata Search Time", logger):
metadata_hits = {
result["corpus_id"]: result["score"]
for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
}
# Sum metadata, image scores of the highest ranked images
@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
return results
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
# Initialize Model
encoder = initialize_model(search_config)
def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
# Extract Entries
absolute_image_files, filtered_image_files = set(), set()
if config.input_directories:
@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
use_xmp_metadata=config.use_xmp_metadata,
)
return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)

View file

@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.jsonl import load_jsonl
@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# Number of entries we want to retrieve with the bi-encoder
top_k = 15
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
device=f"{state.device}",
)
return bi_encoder, cross_encoder, top_k
return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]:
@ -67,7 +64,7 @@ def compute_embeddings(
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings
@ -104,17 +101,18 @@ def compute_embeddings(
async def query(
raw_query: str,
model: TextSearchModel,
search_model: TextSearchModel,
content: TextContent,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
@ -127,18 +125,17 @@ async def query(
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
with timer("Search Time", logger, state.device):
hits = util.semantic_search(
question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
)[0]
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
def setup(
text_to_jsonl: Type[TextToJsonl],
config: TextConfigBase,
search_config: TextSearchConfig,
bi_encoder: BaseEncoder,
regenerate: bool,
filters: List[BaseFilter] = [],
) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = (
@ -192,7 +186,6 @@ def setup(
if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(f"No valid entries found in specified files: {config_params}")
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@ -203,7 +196,7 @@ def setup(
for filter in filters:
filter.load(entries, regenerate=regenerate)
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(

View file

@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
# External Packages
import torch
@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
Conversation = "conversation"
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
@dataclass
class ImageContent:
image_names: List[str]
image_embeddings: torch.Tensor
image_metadata_embeddings: torch.Tensor
@dataclass
class TextSearchModel:
def __init__(
self,
entries: List[Entry],
corpus_embeddings: torch.Tensor,
bi_encoder: BaseEncoder,
cross_encoder: CrossEncoder,
filters: List[BaseFilter],
top_k,
):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k
bi_encoder: BaseEncoder
cross_encoder: Optional[CrossEncoder] = None
top_k: Optional[int] = 15
@dataclass
class ImageSearchModel:
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings
self.image_metadata_embeddings = image_metadata_embeddings
self.image_encoder = image_encoder
image_encoder: BaseEncoder
@dataclass
class ContentIndex:
org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass
class SearchModels:
org_search: Union[TextSearchModel, None] = None
markdown_search: Union[TextSearchModel, None] = None
pdf_search: Union[TextSearchModel, None] = None
image_search: Union[ImageSearchModel, None] = None
github_search: Union[TextSearchModel, None] = None
notion_search: Union[TextSearchModel, None] = None
plugin_search: Union[Dict[str, TextSearchModel], None] = None
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
plugin_search: Optional[Dict[str, TextSearchModel]] = None
class ConversationProcessorConfigModel:

View file

@ -20,7 +20,7 @@ from khoj.utils import constants
if TYPE_CHECKING:
# External Packages
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
# Internal Packages
from khoj.utils.models import BaseEncoder
@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
def load_model(
model_name: str, model_type, model_dir=None, device: str = None
) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
logger = logging.getLogger(__name__)

View file

@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig]
content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] = None
app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)

View file

@ -9,13 +9,14 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import SearchModels, ProcessorConfigModel
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
# Application Global State
config = FullConfig()
model = SearchModels()
search_models = SearchModels()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
config_file: Path = None
verbose: int = 0