Reuse Search Models across Content Types to Reduce Memory Consumption
- Memory consumption now scales only with the search models used, not with the content types enabled as well. Previously each content type had its own copy of the search ML models, which cost 300+ MB per enabled content type
- Split model state into two separate state objects, `search_models` and `content_index`. This allows loading the text_search and image_search models first and then reusing them across all content types in the content index; a sketch of the new flow follows below
- This should cut down memory utilization quite a bit for most users. I see a ~50% drop in memory utilization. This will, of course, vary for each user based on the amount of content indexed vs the number of plugins enabled
- This does not solve RAM utilization scaling with the size of the index, as the whole content index is still kept in RAM while Khoj is running

Should help with #195, #301 and #303
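For reference, a minimal sketch of the new two-step initialization this commit introduces. The `initialize` wrapper is illustrative only; the real call sites are `configure_server` and the `/update` API in the diff below. It assumes a valid, already-loaded khoj `FullConfig` and elides error handling:

    # Illustrative sketch, not the actual call site.
    from khoj.configure import configure_content, configure_search
    from khoj.utils.config import ContentIndex, SearchModels

    def initialize(config, regenerate: bool = False):
        # Load each search model exactly once (the text bi-/cross-encoder
        # pair is the 300+ MB cost that used to be paid per content type)
        search_models: SearchModels = configure_search(SearchModels(), config.search_type)

        # Build the per-content-type indices, reusing the shared models:
        # every text content type (org, markdown, pdf, github, notion, plugins)
        # is embedded with the same search_models.text_search.bi_encoder
        content_index: ContentIndex = configure_content(
            ContentIndex(), config.content_type, search_models, regenerate
        )
        return search_models, content_index

Enabling another content type now only grows the index held in `content_index`, not the model memory held in `search_models`.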
This commit is contained in:
parent c2249eadb2
commit 86e2bec9a0

8 changed files with 217 additions and 142 deletions
khoj/configure.py
@@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
 from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
 from khoj.search_type import image_search, text_search
 from khoj.utils import constants, state
-from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
+from khoj.utils.config import (
+    ContentIndex,
+    SearchType,
+    SearchModels,
+    ProcessorConfigModel,
+    ConversationProcessorConfigModel,
+)
 from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
-from khoj.utils.rawconfig import FullConfig, ProcessorConfig
+from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_filter.file_filter import FileFilter
@@ -49,12 +55,20 @@ def configure_server(args, required=False):
     # Initialize Processor from Config
     state.processor_config = configure_processor(args.config.processor)
 
-    # Initialize the search type and model from Config
+    # Initialize Search Models from Config
     state.search_index_lock.acquire()
     state.SearchType = configure_search_types(state.config)
-    state.model = configure_search(state.model, state.config, args.regenerate)
+    state.search_models = configure_search(state.search_models, state.config.search_type)
     state.search_index_lock.release()
 
+    # Initialize Content from Config
+    if state.search_models:
+        state.search_index_lock.acquire()
+        state.content_index = configure_content(
+            state.content_index, state.config.content_type, state.search_models, args.regenerate
+        )
+        state.search_index_lock.release()
+
 
 def configure_routes(app):
     # Import APIs here to setup search types before while configuring server
@@ -73,7 +87,9 @@ if not state.demo:
    @schedule.repeat(schedule.every(61).minutes)
    def update_search_index():
        state.search_index_lock.acquire()
-       state.model = configure_search(state.model, state.config, regenerate=False)
+       state.content_index = configure_content(
+           state.content_index, state.config.content_type, state.search_models, regenerate=False
+       )
        state.search_index_lock.release()
        logger.info("📬 Search index updated via Scheduler")
 
@@ -90,94 +106,116 @@ def configure_search_types(config: FullConfig):
     return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
 
 
-def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
-    if config is None or config.content_type is None or config.search_type is None:
-        logger.warning("🚨 No Content or Search type is configured.")
-        return
+def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
+    # Run Validation Checks
+    if search_config is None:
+        logger.warning("🚨 No Search type is configured.")
+        return None
+    if search_models is None:
+        search_models = SearchModels()
 
-    if model is None:
-        model = SearchModels()
+    # Initialize Search Models
+    if search_config.asymmetric:
+        logger.info("🔍 📜 Setting up text search model")
+        search_models.text_search = text_search.initialize_model(search_config.asymmetric)
+
+    if search_config.image:
+        logger.info("🔍 🌄 Setting up image search model")
+        search_models.image_search = image_search.initialize_model(search_config.image)
+
+    return search_models
+
+
+def configure_content(
+    content_index: Optional[ContentIndex],
+    content_config: Optional[ContentConfig],
+    search_models: SearchModels,
+    regenerate: bool,
+    t: Optional[state.SearchType] = None,
+) -> Optional[ContentIndex]:
+    # Run Validation Checks
+    if content_config is None:
+        logger.warning("🚨 No Content type is configured.")
+        return None
+    if content_index is None:
+        content_index = ContentIndex()
 
     try:
         # Initialize Org Notes Search
-        if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric:
+        if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search:
             logger.info("🦄 Setting up search for orgmode notes")
             # Extract Entries, Generate Notes Embeddings
-            model.org_search = text_search.setup(
+            content_index.org = text_search.setup(
                 OrgToJsonl,
-                config.content_type.org,
-                search_config=config.search_type.asymmetric,
+                content_config.org,
+                search_models.text_search.bi_encoder,
                 regenerate=regenerate,
                 filters=[DateFilter(), WordFilter(), FileFilter()],
             )
 
         # Initialize Markdown Search
-        if (
-            (t == state.SearchType.Markdown or t == None)
-            and config.content_type.markdown
-            and config.search_type.asymmetric
-        ):
+        if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search:
             logger.info("💎 Setting up search for markdown notes")
             # Extract Entries, Generate Markdown Embeddings
-            model.markdown_search = text_search.setup(
+            content_index.markdown = text_search.setup(
                 MarkdownToJsonl,
-                config.content_type.markdown,
-                search_config=config.search_type.asymmetric,
+                content_config.markdown,
+                search_models.text_search.bi_encoder,
                 regenerate=regenerate,
                 filters=[DateFilter(), WordFilter(), FileFilter()],
            )
 
         # Initialize PDF Search
-        if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric:
+        if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search:
             logger.info("🖨️ Setting up search for pdf")
             # Extract Entries, Generate PDF Embeddings
-            model.pdf_search = text_search.setup(
+            content_index.pdf = text_search.setup(
                 PdfToJsonl,
-                config.content_type.pdf,
-                search_config=config.search_type.asymmetric,
+                content_config.pdf,
+                search_models.text_search.bi_encoder,
                 regenerate=regenerate,
                 filters=[DateFilter(), WordFilter(), FileFilter()],
             )
 
         # Initialize Image Search
-        if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image:
+        if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search:
             logger.info("🌄 Setting up search for images")
             # Extract Entries, Generate Image Embeddings
-            model.image_search = image_search.setup(
-                config.content_type.image, search_config=config.search_type.image, regenerate=regenerate
+            content_index.image = image_search.setup(
+                content_config.image, search_models.image_search.image_encoder, regenerate=regenerate
             )
 
-        if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric:
+        if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search:
             logger.info("🐙 Setting up search for github")
             # Extract Entries, Generate Github Embeddings
-            model.github_search = text_search.setup(
+            content_index.github = text_search.setup(
                 GithubToJsonl,
-                config.content_type.github,
-                search_config=config.search_type.asymmetric,
+                content_config.github,
+                search_models.text_search.bi_encoder,
                 regenerate=regenerate,
                 filters=[DateFilter(), WordFilter(), FileFilter()],
            )
 
         # Initialize External Plugin Search
-        if (t == None or t in state.SearchType) and config.content_type.plugins:
+        if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search:
             logger.info("🔌 Setting up search for plugins")
-            model.plugin_search = {}
-            for plugin_type, plugin_config in config.content_type.plugins.items():
-                model.plugin_search[plugin_type] = text_search.setup(
+            content_index.plugins = {}
+            for plugin_type, plugin_config in content_config.plugins.items():
+                content_index.plugins[plugin_type] = text_search.setup(
                     JsonlToJsonl,
                     plugin_config,
-                    search_config=config.search_type.asymmetric,
+                    search_models.text_search.bi_encoder,
                     regenerate=regenerate,
                     filters=[DateFilter(), WordFilter(), FileFilter()],
                 )
 
         # Initialize Notion Search
-        if (t == None or t in state.SearchType) and config.content_type.notion:
+        if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search:
             logger.info("🔌 Setting up search for notion")
-            model.notion_search = text_search.setup(
+            content_index.notion = text_search.setup(
                 NotionToJsonl,
-                config.content_type.notion,
-                search_config=config.search_type.asymmetric,
+                content_config.notion,
+                search_models.text_search.bi_encoder,
                 regenerate=regenerate,
                 filters=[DateFilter(), WordFilter(), FileFilter()],
            )
@@ -189,7 +227,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None):
     # Invalidate Query Cache
     state.query_cache = LRU()
 
-    return model
+    return content_index
 
 
 def configure_processor(processor_config: ProcessorConfig):
khoj/routers/api.py
@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request
 from sentence_transformers import util
 
 # Internal Packages
-from khoj.configure import configure_processor, configure_search
+from khoj.configure import configure_content, configure_processor, configure_search
 from khoj.search_type import image_search, text_search
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
@@ -102,17 +102,17 @@ if not state.demo:
         state.config.content_type[content_type] = None
 
         if content_type == "github":
-            state.model.github_search = None
+            state.content_index.github = None
         elif content_type == "notion":
-            state.model.notion_search = None
+            state.content_index.notion = None
         elif content_type == "plugins":
-            state.model.plugin_search = None
+            state.content_index.plugins = None
         elif content_type == "pdf":
-            state.model.pdf_search = None
+            state.content_index.pdf = None
         elif content_type == "markdown":
-            state.model.markdown_search = None
+            state.content_index.markdown = None
         elif content_type == "org":
-            state.model.org_search = None
+            state.content_index.org = None
 
         try:
             save_config_to_file_updated_state()
@@ -182,7 +182,7 @@ def get_config_types():
         for search_type in SearchType
         if (
             search_type.value in configured_content_types
-            and getattr(state.model, f"{search_type.value}_search") is not None
+            and getattr(state.content_index, search_type.value) is not None
         )
         or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
         or search_type == SearchType.All
@@ -210,7 +210,7 @@ async def search(
     if q is None or q == "":
         logger.warning(f"No query param (q) passed in API call to initiate search")
         return results
-    if not state.model or not any(state.model.__dict__.values()):
+    if not state.search_models or not any(state.search_models.__dict__.values()):
         logger.warning(f"No search models loaded. Configure a search model before initiating search")
         return results
 
@@ -234,7 +234,7 @@ async def search(
     encoded_asymmetric_query = None
     if t == SearchType.All or t != SearchType.Image:
         text_search_models: List[TextSearchModel] = [
-            model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel)
+            model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
         ]
         if text_search_models:
             with timer("Encoding query took", logger=logger):
@@ -247,13 +247,14 @@ async def search(
                )
 
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        if (t == SearchType.Org or t == SearchType.All) and state.model.org_search:
+        if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
             # query org-mode notes
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    state.model.org_search,
+                    state.search_models.text_search,
+                    state.content_index.org,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -261,13 +262,18 @@ async def search(
                 )
             ]
 
-        if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search:
+        if (
+            (t == SearchType.Markdown or t == SearchType.All)
+            and state.content_index.markdown
+            and state.search_models.text_search
+        ):
             # query markdown notes
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    state.model.markdown_search,
+                    state.search_models.text_search,
+                    state.content_index.markdown,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -275,13 +281,18 @@ async def search(
                 )
             ]
 
-        if (t == SearchType.Github or t == SearchType.All) and state.model.github_search:
+        if (
+            (t == SearchType.Github or t == SearchType.All)
+            and state.content_index.github
+            and state.search_models.text_search
+        ):
             # query github issues
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    state.model.github_search,
+                    state.search_models.text_search,
+                    state.content_index.github,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -289,13 +300,14 @@ async def search(
                 )
             ]
 
-        if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search:
+        if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
             # query pdf files
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    state.model.pdf_search,
+                    state.search_models.text_search,
+                    state.content_index.pdf,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -303,26 +315,38 @@ async def search(
                 )
             ]
 
-        if (t == SearchType.Image) and state.model.image_search:
+        if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
             # query images
             search_futures += [
                 executor.submit(
                     image_search.query,
                     user_query,
                     results_count,
-                    state.model.image_search,
+                    state.search_models.image_search,
+                    state.content_index.image,
                     score_threshold=score_threshold,
                 )
             ]
 
-        if (t == SearchType.All or t in SearchType) and state.model.plugin_search:
+        if (
+            (t == SearchType.All or t in SearchType)
+            and state.content_index.plugins
+            and state.search_models.plugin_search
+        ):
             # query specified plugin type
+            # Get plugin content, search model for specified search type, or the first one if none specified
+            plugin_search = state.search_models.plugin_search.get(t.value) or next(
+                iter(state.search_models.plugin_search.values())
+            )
+            plugin_content = state.content_index.plugins.get(t.value) or next(
+                iter(state.content_index.plugins.values())
+            )
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    # Get plugin search model for specified search type, or the first one if none specified
-                    state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
+                    plugin_search,
+                    plugin_content,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -330,13 +354,18 @@ async def search(
                 )
             ]
 
-        if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search:
+        if (
+            (t == SearchType.Notion or t == SearchType.All)
+            and state.content_index.notion
+            and state.search_models.text_search
+        ):
             # query notion pages
             search_futures += [
                 executor.submit(
                     text_search.query,
                     user_query,
-                    state.model.notion_search,
+                    state.search_models.text_search,
+                    state.content_index.notion,
                     question_embedding=encoded_asymmetric_query,
                     rank_results=r or False,
                     score_threshold=score_threshold,
@@ -347,13 +376,13 @@ async def search(
     # Query across each requested content types in parallel
     with timer("Query took", logger):
         for search_future in concurrent.futures.as_completed(search_futures):
-            if t == SearchType.Image:
+            if t == SearchType.Image and state.content_index.image:
                 hits = await search_future.result()
                 output_directory = constants.web_directory / "images"
                 # Collate results
                 results += image_search.collate_results(
                     hits,
-                    image_names=state.model.image_search.image_names,
+                    image_names=state.content_index.image.image_names,
                     output_directory=output_directory,
                     image_files_url="/static/images",
                     count=results_count,
@@ -404,7 +433,12 @@ def update(
    try:
        state.search_index_lock.acquire()
        try:
-           state.model = configure_search(state.model, state.config, regenerate=force or False, t=t)
+           if state.config and state.config.search_type:
+               state.search_models = configure_search(state.search_models, state.config.search_type)
+           if state.search_models:
+               state.content_index = configure_content(
+                   state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
+               )
        except Exception as e:
            logger.error(e)
            raise HTTPException(status_code=500, detail=str(e))
khoj/search_type/image_search.py
@@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util
 from PIL import Image
 from tqdm import trange
 import torch
+from khoj.utils import state
 
 # Internal Packages
 from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
-from khoj.utils.config import ImageSearchModel
+from khoj.utils.config import ImageContent, ImageSearchModel
 from khoj.utils.models import BaseEncoder
 from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
 
+
@@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig):
         model_type=search_config.encoder_type or SentenceTransformer,
     )
 
-    return encoder
+    return ImageSearchModel(encoder)
 
 
 def extract_entries(image_directories):
@@ -143,7 +145,9 @@ def extract_metadata(image_name):
     return image_processed_metadata
 
 
-async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
+async def query(
+    raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
+):
     # Set query to image content if query is of form file:/path/to/file.png
     if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
         query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
 
     # Now we encode the query (which can either be an image or a text string)
     with timer("Query Encode Time", logger):
-        query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
+        query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
 
     # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
     with timer("Search Time", logger):
         image_hits = {
             result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
-            for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]
+            for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
         }
 
     # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
-    if model.image_metadata_embeddings:
+    if content.image_metadata_embeddings:
         with timer("Metadata Search Time", logger):
             metadata_hits = {
                 result["corpus_id"]: result["score"]
-                for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]
+                for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0]
             }
 
     # Sum metadata, image scores of the highest ranked images
@@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
     return results
 
 
-def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
-    # Initialize Model
-    encoder = initialize_model(search_config)
-
+def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent:
     # Extract Entries
     absolute_image_files, filtered_image_files = set(), set()
     if config.input_directories:
@@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
         use_xmp_metadata=config.use_xmp_metadata,
     )
 
-    return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder)
+    return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings)
khoj/search_type/text_search.py
@@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter
 # Internal Packages
 from khoj.utils import state
 from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
-from khoj.utils.config import TextSearchModel
+from khoj.utils.config import TextContent, TextSearchModel
 from khoj.utils.models import BaseEncoder
 from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
 from khoj.utils.jsonl import load_jsonl
@@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig):
     "Initialize model for semantic search on text"
     torch.set_num_threads(4)
 
-    # Number of entries we want to retrieve with the bi-encoder
-    top_k = 15
-
     # If model directory is configured
     if search_config.model_directory:
         # Convert model directory to absolute path
@@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig):
         device=f"{state.device}",
     )
 
-    return bi_encoder, cross_encoder, top_k
+    return TextSearchModel(bi_encoder, cross_encoder)
 
 
 def extract_entries(jsonl_file) -> List[Entry]:
@@ -67,7 +64,7 @@ def compute_embeddings(
     new_entries = []
     # Load pre-computed embeddings from file if exists and update them if required
     if embeddings_file.exists() and not regenerate:
-        corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
+        corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
         logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
 
     # Encode any new entries in the corpus and update corpus embeddings
@@ -104,17 +101,18 @@ def compute_embeddings(
 
 async def query(
     raw_query: str,
-    model: TextSearchModel,
+    search_model: TextSearchModel,
+    content: TextContent,
     question_embedding: Union[torch.Tensor, None] = None,
     rank_results: bool = False,
     score_threshold: float = -math.inf,
     dedupe: bool = True,
 ) -> Tuple[List[dict], List[Entry]]:
     "Search for entries that answer the query"
-    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
+    query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
 
     # Filter query, entries and embeddings before semantic search
-    query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
+    query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
 
     # If no entries left after filtering, return empty results
     if entries is None or len(entries) == 0:
@@ -127,18 +125,17 @@ async def query(
     # Encode the query using the bi-encoder
     if question_embedding is None:
         with timer("Query Encode Time", logger, state.device):
-            question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
+            question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
             question_embedding = util.normalize_embeddings(question_embedding)
 
     # Find relevant entries for the query
+    top_k = min(len(entries), search_model.top_k or 10)  # top_k hits can't be more than the total entries in corpus
     with timer("Search Time", logger, state.device):
-        hits = util.semantic_search(
-            question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score
-        )[0]
+        hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
 
     # Score all retrieved entries using the cross-encoder
-    if rank_results:
-        hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
+    if rank_results and search_model.cross_encoder:
+        hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
 
     # Filter results by score threshold
     hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
@@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
 def setup(
     text_to_jsonl: Type[TextToJsonl],
     config: TextConfigBase,
-    search_config: TextSearchConfig,
+    bi_encoder: BaseEncoder,
     regenerate: bool,
     filters: List[BaseFilter] = [],
-) -> TextSearchModel:
-    # Initialize Model
-    bi_encoder, cross_encoder, top_k = initialize_model(search_config)
-
+) -> TextContent:
     # Map notes in text files to (compressed) JSONL formatted file
     config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
     previous_entries = (
@@ -192,7 +186,6 @@ def setup(
     if is_none_or_empty(entries):
         config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
         raise ValueError(f"No valid entries found in specified files: {config_params}")
-    top_k = min(len(entries), top_k)  # top_k hits can't be more than the total entries in corpus
 
     # Compute or Load Embeddings
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
@@ -203,7 +196,7 @@ def setup(
     for filter in filters:
         filter.load(entries, regenerate=regenerate)
 
-    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
+    return TextContent(entries, corpus_embeddings, filters)
 
 
 def apply_filters(
khoj/utils/config.py
@@ -3,7 +3,7 @@ from __future__ import annotations  # to avoid quoting type hints
 from enum import Enum
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 # External Packages
 import torch
@@ -30,42 +30,48 @@ class ProcessorType(str, Enum):
     Conversation = "conversation"
 
 
+@dataclass
+class TextContent:
+    entries: List[Entry]
+    corpus_embeddings: torch.Tensor
+    filters: List[BaseFilter]
+
+
+@dataclass
+class ImageContent:
+    image_names: List[str]
+    image_embeddings: torch.Tensor
+    image_metadata_embeddings: torch.Tensor
+
+
+@dataclass
 class TextSearchModel:
-    def __init__(
-        self,
-        entries: List[Entry],
-        corpus_embeddings: torch.Tensor,
-        bi_encoder: BaseEncoder,
-        cross_encoder: CrossEncoder,
-        filters: List[BaseFilter],
-        top_k,
-    ):
-        self.entries = entries
-        self.corpus_embeddings = corpus_embeddings
-        self.bi_encoder = bi_encoder
-        self.cross_encoder = cross_encoder
-        self.filters = filters
-        self.top_k = top_k
+    bi_encoder: BaseEncoder
+    cross_encoder: Optional[CrossEncoder] = None
+    top_k: Optional[int] = 15
 
 
+@dataclass
 class ImageSearchModel:
-    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
-        self.image_encoder = image_encoder
-        self.image_names = image_names
-        self.image_embeddings = image_embeddings
-        self.image_metadata_embeddings = image_metadata_embeddings
-        self.image_encoder = image_encoder
+    image_encoder: BaseEncoder
+
+
+@dataclass
+class ContentIndex:
+    org: Optional[TextContent] = None
+    markdown: Optional[TextContent] = None
+    pdf: Optional[TextContent] = None
+    github: Optional[TextContent] = None
+    notion: Optional[TextContent] = None
+    image: Optional[ImageContent] = None
+    plugins: Optional[Dict[str, TextContent]] = None
 
 
 @dataclass
 class SearchModels:
-    org_search: Union[TextSearchModel, None] = None
-    markdown_search: Union[TextSearchModel, None] = None
-    pdf_search: Union[TextSearchModel, None] = None
-    image_search: Union[ImageSearchModel, None] = None
-    github_search: Union[TextSearchModel, None] = None
-    notion_search: Union[TextSearchModel, None] = None
-    plugin_search: Union[Dict[str, TextSearchModel], None] = None
+    text_search: Optional[TextSearchModel] = None
+    image_search: Optional[ImageSearchModel] = None
+    plugin_search: Optional[Dict[str, TextSearchModel]] = None
 
 
 class ConversationProcessorConfigModel:
khoj/utils/helpers.py
@@ -20,7 +20,7 @@ from khoj.utils import constants
 
 if TYPE_CHECKING:
     # External Packages
-    from sentence_transformers import CrossEncoder
+    from sentence_transformers import SentenceTransformer, CrossEncoder
 
     # Internal Packages
     from khoj.utils.models import BaseEncoder
@@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
     return merged_dict
 
 
-def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]:
+def load_model(
+    model_name: str, model_type, model_dir=None, device: str = None
+) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]:
     "Load model from disk or huggingface"
     # Construct model path
     logger = logging.getLogger(__name__)
khoj/utils/rawconfig.py
@@ -119,9 +119,9 @@ class AppConfig(ConfigBase):
 
 
 class FullConfig(ConfigBase):
-    content_type: Optional[ContentConfig]
-    search_type: Optional[SearchConfig]
-    processor: Optional[ProcessorConfig]
+    content_type: Optional[ContentConfig] = None
+    search_type: Optional[SearchConfig] = None
+    processor: Optional[ProcessorConfig] = None
     app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
 
 
khoj/utils/state.py
@@ -9,13 +9,14 @@ from pathlib import Path
 
 # Internal Packages
 from khoj.utils import config as utils_config
-from khoj.utils.config import SearchModels, ProcessorConfigModel
+from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
 from khoj.utils.helpers import LRU
 from khoj.utils.rawconfig import FullConfig
 
 # Application Global State
 config = FullConfig()
-model = SearchModels()
+search_models = SearchModels()
+content_index = ContentIndex()
 processor_config = ProcessorConfigModel()
 config_file: Path = None
 verbose: int = 0