From 86e2bec9a092b1f386ef0a43ff18aa67b4488ce6 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 14 Jul 2023 01:07:44 -0700 Subject: [PATCH 1/3] Reuse Search Models across Content Types to Reduce Memory Consumption - Memory consumption now only scales with search models used, not with content types as well. Previously each content type had it's own copy of the search ML models. That'd result in 300+ Mb per enabled content type - Split model state into 2 separate state objects, `search_models' and `content_index'. This allows loading text_search and image_search models first and then reusing them across all content_types in content_index - This should cut down memory utilization quite a bit for most users. I see a ~50% drop in memory utilization. This will, of course, vary for each user based on the amount of content indexed vs number of plugins enabled - This does not solve the RAM utilization scaling with size of the index. As the whole content index is still kept in RAM while Khoj is running Should help with #195, #301 and #303 --- src/khoj/configure.py | 126 +++++++++++++++++---------- src/khoj/routers/api.py | 90 +++++++++++++------ src/khoj/search_type/image_search.py | 25 +++--- src/khoj/search_type/text_search.py | 37 ++++---- src/khoj/utils/config.py | 64 ++++++++------ src/khoj/utils/helpers.py | 6 +- src/khoj/utils/rawconfig.py | 6 +- src/khoj/utils/state.py | 5 +- 8 files changed, 217 insertions(+), 142 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a78a1bdf..c1f8ff04 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -20,9 +20,15 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.search_type import image_search, text_search from khoj.utils import constants, state -from khoj.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel +from khoj.utils.config import ( + ContentIndex, + SearchType, + SearchModels, + ProcessorConfigModel, + ConversationProcessorConfigModel, +) from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts -from khoj.utils.rawconfig import FullConfig, ProcessorConfig +from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter @@ -49,12 +55,20 @@ def configure_server(args, required=False): # Initialize Processor from Config state.processor_config = configure_processor(args.config.processor) - # Initialize the search type and model from Config + # Initialize Search Models from Config state.search_index_lock.acquire() state.SearchType = configure_search_types(state.config) - state.model = configure_search(state.model, state.config, args.regenerate) + state.search_models = configure_search(state.search_models, state.config.search_type) state.search_index_lock.release() + # Initialize Content from Config + if state.search_models: + state.search_index_lock.acquire() + state.content_index = configure_content( + state.content_index, state.config.content_type, state.search_models, args.regenerate + ) + state.search_index_lock.release() + def configure_routes(app): # Import APIs here to setup search types before while configuring server @@ -73,7 +87,9 @@ if not state.demo: @schedule.repeat(schedule.every(61).minutes) def update_search_index(): 
state.search_index_lock.acquire() - state.model = configure_search(state.model, state.config, regenerate=False) + state.content_index = configure_content( + state.content_index, state.config.content_type, state.search_models, regenerate=False + ) state.search_index_lock.release() logger.info("📬 Search index updated via Scheduler") @@ -90,94 +106,116 @@ def configure_search_types(config: FullConfig): return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) -def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: Optional[state.SearchType] = None): - if config is None or config.content_type is None or config.search_type is None: - logger.warning("🚨 No Content or Search type is configured.") - return +def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]: + # Run Validation Checks + if search_config is None: + logger.warning("🚨 No Search type is configured.") + return None + if search_models is None: + search_models = SearchModels() - if model is None: - model = SearchModels() + # Initialize Search Models + if search_config.asymmetric: + logger.info("🔍 📜 Setting up text search model") + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + + if search_config.image: + logger.info("🔍 🌄 Setting up image search model") + search_models.image_search = image_search.initialize_model(search_config.image) + + return search_models + + +def configure_content( + content_index: Optional[ContentIndex], + content_config: Optional[ContentConfig], + search_models: SearchModels, + regenerate: bool, + t: Optional[state.SearchType] = None, +) -> Optional[ContentIndex]: + # Run Validation Checks + if content_config is None: + logger.warning("🚨 No Content type is configured.") + return None + if content_index is None: + content_index = ContentIndex() try: # Initialize Org Notes Search - if (t == state.SearchType.Org or t == None) and config.content_type.org and config.search_type.asymmetric: + if (t == state.SearchType.Org or t == None) and content_config.org and search_models.text_search: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings - model.org_search = text_search.setup( + content_index.org = text_search.setup( OrgToJsonl, - config.content_type.org, - search_config=config.search_type.asymmetric, + content_config.org, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) # Initialize Markdown Search - if ( - (t == state.SearchType.Markdown or t == None) - and config.content_type.markdown - and config.search_type.asymmetric - ): + if (t == state.SearchType.Markdown or t == None) and content_config.markdown and search_models.text_search: logger.info("💎 Setting up search for markdown notes") # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup( + content_index.markdown = text_search.setup( MarkdownToJsonl, - config.content_type.markdown, - search_config=config.search_type.asymmetric, + content_config.markdown, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) # Initialize PDF Search - if (t == state.SearchType.Pdf or t == None) and config.content_type.pdf and config.search_type.asymmetric: + if (t == state.SearchType.Pdf or t == None) and content_config.pdf and search_models.text_search: logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF 
Embeddings - model.pdf_search = text_search.setup( + content_index.pdf = text_search.setup( PdfToJsonl, - config.content_type.pdf, - search_config=config.search_type.asymmetric, + content_config.pdf, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) # Initialize Image Search - if (t == state.SearchType.Image or t == None) and config.content_type.image and config.search_type.image: + if (t == state.SearchType.Image or t == None) and content_config.image and search_models.image_search: logger.info("🌄 Setting up search for images") # Extract Entries, Generate Image Embeddings - model.image_search = image_search.setup( - config.content_type.image, search_config=config.search_type.image, regenerate=regenerate + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=regenerate ) - if (t == state.SearchType.Github or t == None) and config.content_type.github and config.search_type.asymmetric: + if (t == state.SearchType.Github or t == None) and content_config.github and search_models.text_search: logger.info("🐙 Setting up search for github") # Extract Entries, Generate Github Embeddings - model.github_search = text_search.setup( + content_index.github = text_search.setup( GithubToJsonl, - config.content_type.github, - search_config=config.search_type.asymmetric, + content_config.github, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) # Initialize External Plugin Search - if (t == None or t in state.SearchType) and config.content_type.plugins: + if (t == None or t in state.SearchType) and content_config.plugins and search_models.text_search: logger.info("🔌 Setting up search for plugins") - model.plugin_search = {} - for plugin_type, plugin_config in config.content_type.plugins.items(): - model.plugin_search[plugin_type] = text_search.setup( + content_index.plugins = {} + for plugin_type, plugin_config in content_config.plugins.items(): + content_index.plugins[plugin_type] = text_search.setup( JsonlToJsonl, plugin_config, - search_config=config.search_type.asymmetric, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) # Initialize Notion Search - if (t == None or t in state.SearchType) and config.content_type.notion: + if (t == None or t in state.SearchType) and content_config.notion and search_models.text_search: logger.info("🔌 Setting up search for notion") - model.notion_search = text_search.setup( + content_index.notion = text_search.setup( NotionToJsonl, - config.content_type.notion, - search_config=config.search_type.asymmetric, + content_config.notion, + search_models.text_search.bi_encoder, regenerate=regenerate, filters=[DateFilter(), WordFilter(), FileFilter()], ) @@ -189,7 +227,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, # Invalidate Query Cache state.query_cache = LRU() - return model + return content_index def configure_processor(processor_config: ProcessorConfig): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d04284e5..5f3c5cdd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Header, Request from sentence_transformers import util # Internal Packages -from khoj.configure import configure_processor, configure_search +from khoj.configure import configure_content, configure_processor, 
configure_search from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter @@ -102,17 +102,17 @@ if not state.demo: state.config.content_type[content_type] = None if content_type == "github": - state.model.github_search = None + state.content_index.github = None elif content_type == "notion": - state.model.notion_search = None + state.content_index.notion = None elif content_type == "plugins": - state.model.plugin_search = None + state.content_index.plugins = None elif content_type == "pdf": - state.model.pdf_search = None + state.content_index.pdf = None elif content_type == "markdown": - state.model.markdown_search = None + state.content_index.markdown = None elif content_type == "org": - state.model.org_search = None + state.content_index.org = None try: save_config_to_file_updated_state() @@ -182,7 +182,7 @@ def get_config_types(): for search_type in SearchType if ( search_type.value in configured_content_types - and getattr(state.model, f"{search_type.value}_search") is not None + and getattr(state.content_index, search_type.value) is not None ) or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) or search_type == SearchType.All @@ -210,7 +210,7 @@ async def search( if q is None or q == "": logger.warning(f"No query param (q) passed in API call to initiate search") return results - if not state.model or not any(state.model.__dict__.values()): + if not state.search_models or not any(state.search_models.__dict__.values()): logger.warning(f"No search models loaded. Configure a search model before initiating search") return results @@ -234,7 +234,7 @@ async def search( encoded_asymmetric_query = None if t == SearchType.All or t != SearchType.Image: text_search_models: List[TextSearchModel] = [ - model for model in state.model.__dict__.values() if isinstance(model, TextSearchModel) + model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel) ] if text_search_models: with timer("Encoding query took", logger=logger): @@ -247,13 +247,14 @@ async def search( ) with concurrent.futures.ThreadPoolExecutor() as executor: - if (t == SearchType.Org or t == SearchType.All) and state.model.org_search: + if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search: # query org-mode notes search_futures += [ executor.submit( text_search.query, user_query, - state.model.org_search, + state.search_models.text_search, + state.content_index.org, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -261,13 +262,18 @@ async def search( ) ] - if (t == SearchType.Markdown or t == SearchType.All) and state.model.markdown_search: + if ( + (t == SearchType.Markdown or t == SearchType.All) + and state.content_index.markdown + and state.search_models.text_search + ): # query markdown notes search_futures += [ executor.submit( text_search.query, user_query, - state.model.markdown_search, + state.search_models.text_search, + state.content_index.markdown, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -275,13 +281,18 @@ async def search( ) ] - if (t == SearchType.Github or t == SearchType.All) and state.model.github_search: + if ( + (t == SearchType.Github or t == SearchType.All) + and state.content_index.github + and state.search_models.text_search + ): # query github 
issues search_futures += [ executor.submit( text_search.query, user_query, - state.model.github_search, + state.search_models.text_search, + state.content_index.github, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -289,13 +300,14 @@ async def search( ) ] - if (t == SearchType.Pdf or t == SearchType.All) and state.model.pdf_search: + if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search: # query pdf files search_futures += [ executor.submit( text_search.query, user_query, - state.model.pdf_search, + state.search_models.text_search, + state.content_index.pdf, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -303,26 +315,38 @@ async def search( ) ] - if (t == SearchType.Image) and state.model.image_search: + if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search: # query images search_futures += [ executor.submit( image_search.query, user_query, results_count, - state.model.image_search, + state.search_models.image_search, + state.content_index.image, score_threshold=score_threshold, ) ] - if (t == SearchType.All or t in SearchType) and state.model.plugin_search: + if ( + (t == SearchType.All or t in SearchType) + and state.content_index.plugins + and state.search_models.plugin_search + ): # query specified plugin type + # Get plugin content, search model for specified search type, or the first one if none specified + plugin_search = state.search_models.plugin_search.get(t.value) or next( + iter(state.search_models.plugin_search.values()) + ) + plugin_content = state.content_index.plugins.get(t.value) or next( + iter(state.content_index.plugins.values()) + ) search_futures += [ executor.submit( text_search.query, user_query, - # Get plugin search model for specified search type, or the first one if none specified - state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), + plugin_search, + plugin_content, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -330,13 +354,18 @@ async def search( ) ] - if (t == SearchType.Notion or t == SearchType.All) and state.model.notion_search: + if ( + (t == SearchType.Notion or t == SearchType.All) + and state.content_index.notion + and state.search_models.text_search + ): # query notion pages search_futures += [ executor.submit( text_search.query, user_query, - state.model.notion_search, + state.search_models.text_search, + state.content_index.notion, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, @@ -347,13 +376,13 @@ async def search( # Query across each requested content types in parallel with timer("Query took", logger): for search_future in concurrent.futures.as_completed(search_futures): - if t == SearchType.Image: + if t == SearchType.Image and state.content_index.image: hits = await search_future.result() output_directory = constants.web_directory / "images" # Collate results results += image_search.collate_results( hits, - image_names=state.model.image_search.image_names, + image_names=state.content_index.image.image_names, output_directory=output_directory, image_files_url="/static/images", count=results_count, @@ -404,7 +433,12 @@ def update( try: state.search_index_lock.acquire() try: - state.model = configure_search(state.model, state.config, regenerate=force or False, t=t) + if state.config 
and state.config.search_type: + state.search_models = configure_search(state.search_models, state.config.search_type) + if state.search_models: + state.content_index = configure_content( + state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t + ) except Exception as e: logger.error(e) raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index d6cc33d6..8b92d9db 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -12,10 +12,12 @@ from sentence_transformers import SentenceTransformer, util from PIL import Image from tqdm import trange import torch +from khoj.utils import state # Internal Packages from khoj.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer -from khoj.utils.config import ImageSearchModel +from khoj.utils.config import ImageContent, ImageSearchModel +from khoj.utils.models import BaseEncoder from khoj.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse @@ -40,7 +42,7 @@ def initialize_model(search_config: ImageSearchConfig): model_type=search_config.encoder_type or SentenceTransformer, ) - return encoder + return ImageSearchModel(encoder) def extract_entries(image_directories): @@ -143,7 +145,9 @@ def extract_metadata(image_name): return image_processed_metadata -async def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf): +async def query( + raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf +): # Set query to image content if query is of form file:/path/to/file.png if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) @@ -158,21 +162,21 @@ async def query(raw_query, count, model: ImageSearchModel, score_threshold: floa # Now we encode the query (which can either be an image or a text string) with timer("Query Encode Time", logger): - query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) + query_embedding = search_model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. with timer("Search Time", logger): image_hits = { result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} - for result in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0] + for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] } # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. 
- if model.image_metadata_embeddings: + if content.image_metadata_embeddings: with timer("Metadata Search Time", logger): metadata_hits = { result["corpus_id"]: result["score"] - for result in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0] + for result in util.semantic_search(query_embedding, content.image_metadata_embeddings, top_k=count)[0] } # Sum metadata, image scores of the highest ranked images @@ -239,10 +243,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= return results -def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: - # Initialize Model - encoder = initialize_model(search_config) - +def setup(config: ImageContentConfig, encoder: BaseEncoder, regenerate: bool) -> ImageContent: # Extract Entries absolute_image_files, filtered_image_files = set(), set() if config.input_directories: @@ -268,4 +269,4 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera use_xmp_metadata=config.use_xmp_metadata, ) - return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder) + return ImageContent(all_image_files, image_embeddings, image_metadata_embeddings) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 09057f9a..a77be6e1 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -13,7 +13,7 @@ from khoj.search_filter.base_filter import BaseFilter # Internal Packages from khoj.utils import state from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer -from khoj.utils.config import TextSearchModel +from khoj.utils.config import TextContent, TextSearchModel from khoj.utils.models import BaseEncoder from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry from khoj.utils.jsonl import load_jsonl @@ -26,9 +26,6 @@ def initialize_model(search_config: TextSearchConfig): "Initialize model for semantic search on text" torch.set_num_threads(4) - # Number of entries we want to retrieve with the bi-encoder - top_k = 15 - # If model directory is configured if search_config.model_directory: # Convert model directory to absolute path @@ -52,7 +49,7 @@ def initialize_model(search_config: TextSearchConfig): device=f"{state.device}", ) - return bi_encoder, cross_encoder, top_k + return TextSearchModel(bi_encoder, cross_encoder) def extract_entries(jsonl_file) -> List[Entry]: @@ -67,7 +64,7 @@ def compute_embeddings( new_entries = [] # Load pre-computed embeddings from file if exists and update them if required if embeddings_file.exists() and not regenerate: - corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device) + corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device) logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}") # Encode any new entries in the corpus and update corpus embeddings @@ -104,17 +101,18 @@ def compute_embeddings( async def query( raw_query: str, - model: TextSearchModel, + search_model: TextSearchModel, + content: TextContent, question_embedding: Union[torch.Tensor, None] = None, rank_results: bool = False, score_threshold: float = -math.inf, dedupe: bool = True, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" - query, entries, corpus_embeddings = raw_query, model.entries, 
model.corpus_embeddings + query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings # Filter query, entries and embeddings before semantic search - query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) + query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters) # If no entries left after filtering, return empty results if entries is None or len(entries) == 0: @@ -127,18 +125,17 @@ async def query( # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) + question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = util.normalize_embeddings(question_embedding) # Find relevant entries for the query + top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus with timer("Search Time", logger, state.device): - hits = util.semantic_search( - question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score - )[0] + hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0] # Score all retrieved entries using the cross-encoder - if rank_results: - hits = cross_encoder_score(model.cross_encoder, query, entries, hits) + if rank_results and search_model.cross_encoder: + hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits) # Filter results by score threshold hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold] @@ -173,13 +170,10 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse] def setup( text_to_jsonl: Type[TextToJsonl], config: TextConfigBase, - search_config: TextSearchConfig, + bi_encoder: BaseEncoder, regenerate: bool, filters: List[BaseFilter] = [], -) -> TextSearchModel: - # Initialize Model - bi_encoder, cross_encoder, top_k = initialize_model(search_config) - +) -> TextContent: # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) previous_entries = ( @@ -192,7 +186,6 @@ def setup( if is_none_or_empty(entries): config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()]) raise ValueError(f"No valid entries found in specified files: {config_params}") - top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) @@ -203,7 +196,7 @@ def setup( for filter in filters: filter.load(entries, regenerate=regenerate) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) + return TextContent(entries, corpus_embeddings, filters) def apply_filters( diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 7887e9cd..6ba8b639 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints from enum import Enum from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union # External Packages import torch @@ -30,42 +30,48 @@ class ProcessorType(str, 
Enum): Conversation = "conversation" +@dataclass +class TextContent: + entries: List[Entry] + corpus_embeddings: torch.Tensor + filters: List[BaseFilter] + + +@dataclass +class ImageContent: + image_names: List[str] + image_embeddings: torch.Tensor + image_metadata_embeddings: torch.Tensor + + +@dataclass class TextSearchModel: - def __init__( - self, - entries: List[Entry], - corpus_embeddings: torch.Tensor, - bi_encoder: BaseEncoder, - cross_encoder: CrossEncoder, - filters: List[BaseFilter], - top_k, - ): - self.entries = entries - self.corpus_embeddings = corpus_embeddings - self.bi_encoder = bi_encoder - self.cross_encoder = cross_encoder - self.filters = filters - self.top_k = top_k + bi_encoder: BaseEncoder + cross_encoder: Optional[CrossEncoder] = None + top_k: Optional[int] = 15 +@dataclass class ImageSearchModel: - def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder): - self.image_encoder = image_encoder - self.image_names = image_names - self.image_embeddings = image_embeddings - self.image_metadata_embeddings = image_metadata_embeddings - self.image_encoder = image_encoder + image_encoder: BaseEncoder + + +@dataclass +class ContentIndex: + org: Optional[TextContent] = None + markdown: Optional[TextContent] = None + pdf: Optional[TextContent] = None + github: Optional[TextContent] = None + notion: Optional[TextContent] = None + image: Optional[ImageContent] = None + plugins: Optional[Dict[str, TextContent]] = None @dataclass class SearchModels: - org_search: Union[TextSearchModel, None] = None - markdown_search: Union[TextSearchModel, None] = None - pdf_search: Union[TextSearchModel, None] = None - image_search: Union[ImageSearchModel, None] = None - github_search: Union[TextSearchModel, None] = None - notion_search: Union[TextSearchModel, None] = None - plugin_search: Union[Dict[str, TextSearchModel], None] = None + text_search: Optional[TextSearchModel] = None + image_search: Optional[ImageSearchModel] = None + plugin_search: Optional[Dict[str, TextSearchModel]] = None class ConversationProcessorConfigModel: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 7a6cf378..e37f1909 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -20,7 +20,7 @@ from khoj.utils import constants if TYPE_CHECKING: # External Packages - from sentence_transformers import CrossEncoder + from sentence_transformers import SentenceTransformer, CrossEncoder # Internal Packages from khoj.utils.models import BaseEncoder @@ -64,7 +64,9 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name: str, model_type, model_dir=None, device: str = None) -> Union[BaseEncoder, CrossEncoder]: +def load_model( + model_name: str, model_type, model_dir=None, device: str = None +) -> Union[BaseEncoder, SentenceTransformer, CrossEncoder]: "Load model from disk or huggingface" # Construct model path logger = logging.getLogger(__name__) diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 0172dc1f..043576f5 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -119,9 +119,9 @@ class AppConfig(ConfigBase): class FullConfig(ConfigBase): - content_type: Optional[ContentConfig] - search_type: Optional[SearchConfig] - processor: Optional[ProcessorConfig] + content_type: Optional[ContentConfig] = None + search_type: Optional[SearchConfig] = None + processor: Optional[ProcessorConfig] = None app: Optional[AppConfig] = 
AppConfig(should_log_telemetry=True) diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index d59239d8..89688e15 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -9,13 +9,14 @@ from pathlib import Path # Internal Packages from khoj.utils import config as utils_config -from khoj.utils.config import SearchModels, ProcessorConfigModel +from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel from khoj.utils.helpers import LRU from khoj.utils.rawconfig import FullConfig # Application Global State config = FullConfig() -model = SearchModels() +search_models = SearchModels() +content_index = ContentIndex() processor_config = ProcessorConfigModel() config_file: Path = None verbose: int = 0 From b9fb656657e3bed5b8f28908294ef5eb2fa22078 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 14 Jul 2023 01:19:38 -0700 Subject: [PATCH 2/3] Update Tests to setup both content_index, search_models before testing This is required by the updated structure of Khoj setup - Add content_config pytest fixture, pass bi_encoder from search_models.[text|image]_search --- tests/conftest.py | 50 ++++++++++++++++++++++++++++------ tests/test_client.py | 28 +++++++++++++------ tests/test_image_search.py | 40 +++++++++++++++++++-------- tests/test_text_search.py | 56 +++++++++++++++++++++++++------------- 4 files changed, 126 insertions(+), 48 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dfb27b8b..a92d33ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from khoj.main import app from khoj.configure import configure_processor, configure_routes, configure_search_types from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.search_type import image_search, text_search +from khoj.utils.config import ImageContent, SearchModels, TextContent from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, @@ -41,35 +42,49 @@ def search_config() -> SearchConfig: encoder="sentence-transformers/all-MiniLM-L6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory=model_dir / "symmetric/", + encoder_type=None, ) search_config.asymmetric = TextSearchConfig( encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory=model_dir / "asymmetric/", + encoder_type=None, ) search_config.image = ImageSearchConfig( - encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/" + encoder="sentence-transformers/clip-ViT-B-32", + model_directory=model_dir / "image/", + encoder_type=None, ) return search_config @pytest.fixture(scope="session") -def content_config(tmp_path_factory, search_config: SearchConfig): +def search_models(search_config: SearchConfig): + search_models = SearchModels() + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + search_models.image_search = image_search.initialize_model(search_config.image) + + return search_models + + +@pytest.fixture(scope="session") +def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig): content_dir = tmp_path_factory.mktemp("content") # Generate Image Embeddings from Test Images content_config = ContentConfig() content_config.image = ImageContentConfig( + input_filter=None, input_directories=["tests/data/images"], embeddings_file=content_dir.joinpath("image_embeddings.pt"), batch_size=1, use_xmp_metadata=False, ) - 
image_search.setup(content_config.image, search_config.image, regenerate=False) + image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( @@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig): ) filters = [DateFilter(), WordFilter(), FileFilter()] - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters + ) content_config.plugins = { "plugin1": TextContentConfig( @@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig): filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup( - JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters + JsonlToJsonl, + content_config.plugins["plugin1"], + search_models.text_search.bi_encoder, + regenerate=False, + filters=filters, ) return content_config @@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p # Index Markdown Content for Search filters = [DateFilter(), WordFilter(), FileFilter()] - state.model.markdown_search = text_search.setup( - MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters + state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) + state.content_index.markdown = text_search.setup( + MarkdownToJsonl, + md_content_config.markdown, + state.search_models.text_search.bi_encoder, + regenerate=False, + filters=filters, ) # Initialize Processor from Config @@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor state.SearchType = configure_search_types(state.config) # These lines help us Mock the Search models for these search types - state.model.org_search = {} - state.model.image_search = {} + state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) + state.search_models.image_search = image_search.initialize_model(search_config.image) + state.content_index.org = text_search.setup( + OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False + ) + state.content_index.image = image_search.setup( + content_config.image, state.search_models.image_search, regenerate=False + ) configure_routes(app) return TestClient(app) diff --git a/tests/test_client.py b/tests/test_client.py index 81955f39..d86bdd90 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,8 @@ from fastapi.testclient import TestClient from khoj.main import app from khoj.configure import configure_routes, configure_search_types from khoj.utils import state -from khoj.utils.state import model, config +from khoj.utils.config import SearchModels +from khoj.utils.state import search_models, content_index, config from khoj.search_type import text_search, image_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config(): # ---------------------------------------------------------------------------------------------------- def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.image_search = 
image_search.setup(content_config.image, search_config.image, regenerate=False) + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) query_expected_image_pairs = [ ("kitten", "kitten_park.jpg"), ("a horse and dog on a leash", "horse_dog.jpg"), @@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) user_query = quote("How to git install application?") # Act @@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter(), FileFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters ) user_query = quote('+"Emacs" file:"*.org"') @@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters ) user_query = quote('How to git install application? +"Emacs"') @@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig, def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters ) user_query = quote('How to git install application? 
-"clone"') diff --git a/tests/test_image_search.py b/tests/test_image_search.py index e4f08d35..82617ab3 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -5,9 +5,10 @@ from PIL import Image # External Packages import pytest +from khoj.utils.config import SearchModels # Internal Packages -from khoj.utils.state import model +from khoj.utils.state import content_index, search_models from khoj.utils.constants import web_directory from khoj.search_type import image_search from khoj.utils.helpers import resolve_absolute_path @@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig # Test # ---------------------------------------------------------------------------------------------------- -def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): +def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate image search embeddings during image setup - image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True) + image_search_model = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=True + ) # Assert assert len(image_search_model.image_names) == 3 @@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig): @pytest.mark.anyio async def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) output_directory = resolve_absolute_path(web_directory) - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) query_expected_image_pairs = [ ("kitten", "kitten_park.jpg"), ("horse and dog in a farm", "horse_dog.jpg"), @@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search # Act for query, expected_image_name in query_expected_image_pairs: - hits = await image_search.query(query, count=1, model=model.image_search) + hits = await image_search.query( + query, count=1, search_model=search_models.image_search, content=content_index.image + ) results = image_search.collate_results( hits, - model.image_search.image_names, + content_index.image.image_names, output_directory=output_directory, image_files_url="/static/images", count=1, @@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search @pytest.mark.anyio async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): # Arrange - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) max_words_supported = 10 query = " ".join(["hello"] * 100) truncated_query = " ".join(["hello"] * max_words_supported) @@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc # Act try: with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - await image_search.query(query, count=1, model=model.image_search) + await image_search.query( + query, count=1, search_model=search_models.image_search, 
content=content_index.image + ) # Assert except RuntimeError as e: if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): @@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc @pytest.mark.anyio async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): # Arrange + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) output_directory = resolve_absolute_path(web_directory) - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) image_directory = content_config.image.input_directories[0] query = f"file:{image_directory.joinpath('kitten_park.jpg')}" @@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co # Act with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - hits = await image_search.query(query, count=1, model=model.image_search) + hits = await image_search.query( + query, count=1, search_model=search_models.image_search, content=content_index.image + ) results = image_search.collate_results( hits, - model.image_search.image_names, + content_index.image.image_names, output_directory=output_directory, image_files_url="/static/images", count=1, diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 69f58645..c18a4c42 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -5,9 +5,10 @@ import os # External Packages import pytest +from khoj.utils.config import SearchModels # Internal Packages -from khoj.utils.state import model +from khoj.utils.state import content_index, search_models from khoj.search_type import text_search from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error( # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): +def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) # Assert assert len(notes_model.entries) == 10 @@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- -def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog): +def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog): # Arrange caplog.set_level(logging.INFO, logger="khoj") # Act # Generate initial notes embeddings during asymmetric setup - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True) initial_logs = caplog.text caplog.clear() # Clear logs # 
Run asymmetric setup again with no changes to data source. Ensure index is not updated - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False) final_logs = caplog.text # Assert @@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi @pytest.mark.anyio async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) query = "How to git install application?" # Act - hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True) + hits, entries = await text_search.query( + query, search_model=search_models.text_search, content=content_index.org, rank_results=True + ) results = text_search.collate_results(hits, entries, count=1) @@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S # ---------------------------------------------------------------------------------------------------- -def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): +def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels): # Arrange # Insert org-mode entry with size exceeding max token limit to new org file max_tokens = 256 @@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # Act # reload embeddings, entries, notes model after adding new org-mode file initial_notes_model = text_search.setup( - OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False ) # Assert @@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): +def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): # Arrange - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # regenerate notes jsonl, model embeddings and model to include entry from new file regenerated_notes_model = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True ) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(OrgToJsonl, 
content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) # Assert assert len(regenerated_notes_model.entries) == 11 @@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- -def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): +def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): # Arrange - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # Act # update embeddings, entries with the newly added note content_config.org.input_files = [f"{new_org_file}"] - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) # Assert # verify new entry added in updated embeddings, entries @@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # ---------------------------------------------------------------------------------------------------- @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") -def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig): +def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate github embeddings to test asymmetric setup without caching - github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True) + github_model = text_search.setup( + GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True + ) # Assert assert len(github_model.entries) > 1 From f08e9539f1f4962c404f82a8d9633794df2ecfa9 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 14 Jul 2023 16:57:27 -0700 Subject: [PATCH 3/3] Release lock after updating index even if update fails to prevent deadlock Wrap acquire/release locks in try/catch/finally when updating content index and search models to prevent lock not being released on error and causing a deadlock --- src/khoj/configure.py | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index c1f8ff04..18c5ac8a 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -56,18 +56,26 @@ def configure_server(args, required=False): state.processor_config = configure_processor(args.config.processor) # Initialize Search Models from Config - state.search_index_lock.acquire() - state.SearchType = configure_search_types(state.config) - state.search_models = configure_search(state.search_models, state.config.search_type) - state.search_index_lock.release() + try: + state.search_index_lock.acquire() + 
state.SearchType = configure_search_types(state.config) + state.search_models = configure_search(state.search_models, state.config.search_type) + except Exception as e: + logger.error(f"🚨 Error configuring search models on app load: {e}") + finally: + state.search_index_lock.release() # Initialize Content from Config if state.search_models: - state.search_index_lock.acquire() - state.content_index = configure_content( - state.content_index, state.config.content_type, state.search_models, args.regenerate - ) - state.search_index_lock.release() + try: + state.search_index_lock.acquire() + state.content_index = configure_content( + state.content_index, state.config.content_type, state.search_models, args.regenerate + ) + except Exception as e: + logger.error(f"🚨 Error configuring content index on app load: {e}") + finally: + state.search_index_lock.release() def configure_routes(app): @@ -86,12 +94,16 @@ if not state.demo: @schedule.repeat(schedule.every(61).minutes) def update_search_index(): - state.search_index_lock.acquire() - state.content_index = configure_content( - state.content_index, state.config.content_type, state.search_models, regenerate=False - ) - state.search_index_lock.release() - logger.info("📬 Search index updated via Scheduler") + try: + state.search_index_lock.acquire() + state.content_index = configure_content( + state.content_index, state.config.content_type, state.search_models, regenerate=False + ) + logger.info("📬 Content index updated via Scheduler") + except Exception as e: + logger.error(f"🚨 Error updating content index via Scheduler: {e}") + finally: + state.search_index_lock.release() def configure_search_types(config: FullConfig):
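
The sketches below are not part of the patches; they are minimal usage examples reconstructed from the diffs above to show how the reworked interfaces fit together. Helper names such as load_index, search_org_notes, find_images and refresh_content_index are illustrative, and the snippets assume a Khoj install with a config that has search_type and content_type sections loaded.

The first patch splits global state into search_models (the ML encoders, loaded once) and content_index (the per-content-type embeddings that reuse those encoders). A minimal sketch of how the two configuration steps chain together, mirroring the updated configure_server and the /update API handler:

# Sketch: two-step configuration after PATCH 1/3. Encoders load once via
# configure_search(); every enabled content type then reuses them in
# configure_content(). The load_index() wrapper is illustrative.
from khoj.configure import configure_content, configure_search
from khoj.utils import state


def load_index(regenerate: bool = False):
    # Load the shared text and image encoders (independent of content types)
    state.search_models = configure_search(state.search_models, state.config.search_type)

    # Build or update the content index, reusing the shared encoders
    if state.search_models:
        state.content_index = configure_content(
            state.content_index, state.config.content_type, state.search_models, regenerate
        )
    return state.search_models, state.content_index

Because only one TextSearchModel and one ImageSearchModel are ever instantiated, memory now scales with the number of search models rather than with the number of enabled content types.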
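
With the split, text_search.setup takes the shared bi-encoder and returns a TextContent index, and text_search.query takes the search model and the content index as separate arguments. A sketch of that call pattern for org-mode notes, assuming an org_config (TextContentConfig) and search_config like the ones the test fixtures build; search_org_notes is an illustrative helper:

# Sketch: setup and query with the split model/content objects (PATCHES 1-2/3).
# search_org_notes() is an illustrative helper, not part of the patches.
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_type import text_search
from khoj.utils import state


async def search_org_notes(org_config, search_config, user_query: str):
    # Shared text search model: bi-encoder + cross-encoder, loaded once
    state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)

    # Org content index built with the shared bi-encoder
    state.content_index.org = text_search.setup(
        OrgToJsonl, org_config, state.search_models.text_search.bi_encoder, regenerate=False
    )

    # Query now passes the search model and the content index separately
    hits, entries = await text_search.query(
        user_query,
        search_model=state.search_models.text_search,
        content=state.content_index.org,
        rank_results=True,
    )
    return text_search.collate_results(hits, entries, count=5)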
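
The image path follows the same shape: image_search.initialize_model returns an ImageSearchModel wrapping the encoder, setup returns an ImageContent index, and query/collate_results take the two separately. A sketch, where image_config is an ImageContentConfig and the output directory value is a placeholder:

# Sketch: image search with the split objects (PATCH 1/3). The find_images()
# helper and the output_directory default are illustrative.
from khoj.search_type import image_search
from khoj.utils import state


async def find_images(image_config, search_config, user_query: str, output_directory="web/images"):
    # Shared CLIP-style encoder, loaded once
    state.search_models.image_search = image_search.initialize_model(search_config.image)

    # Image content index built with the shared encoder
    state.content_index.image = image_search.setup(
        image_config, state.search_models.image_search.image_encoder, regenerate=False
    )

    hits = await image_search.query(
        user_query, count=1, search_model=state.search_models.image_search, content=state.content_index.image
    )
    return image_search.collate_results(
        hits,
        image_names=state.content_index.image.image_names,
        output_directory=output_directory,
        image_files_url="/static/images",
        count=1,
    )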
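
The third patch guards the index lock with try/except/finally so a failed rebuild can no longer leave the lock held. A sketch of that pattern; refresh_content_index is an illustrative wrapper, and if state.search_index_lock is a standard threading.Lock, a with block would give the same release guarantee:

# Sketch: always release the index lock, even when rebuilding fails (PATCH 3/3).
# refresh_content_index() is an illustrative wrapper around the patched pattern.
import logging

from khoj.configure import configure_content
from khoj.utils import state

logger = logging.getLogger(__name__)


def refresh_content_index(regenerate: bool = False):
    try:
        state.search_index_lock.acquire()
        state.content_index = configure_content(
            state.content_index, state.config.content_type, state.search_models, regenerate=regenerate
        )
        logger.info("📬 Content index updated")
    except Exception as e:
        # Log and swallow the error; the finally block still releases the lock
        logger.error(f"🚨 Error updating content index: {e}")
    finally:
        state.search_index_lock.release()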