diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 84b87ff3..bba03dc2 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -37,45 +37,63 @@ from khoj.search_filter.file_filter import FileFilter logger = logging.getLogger(__name__) -def configure_server(args, required=False): - if args.config is None: - if required: - logger.error( - f"Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}." - ) - sys.exit(1) - else: - logger.warning( - f"Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}." - ) - return - else: - state.config = args.config +def initialize_server( + config: Optional[FullConfig], regenerate: bool, type: Optional[SearchType] = None, required=False +): + if config is None and required: + logger.error( + f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}." + ) + sys.exit(1) + elif config is None: + logger.warning( + f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}." + ) + return None + + try: + configure_server(config, regenerate, type) + except Exception as e: + logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True) + + +def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None): + # Update Config + state.config = config # Initialize Processor from Config - state.processor_config = configure_processor(args.config.processor) + try: + state.config_lock.acquire() + state.processor_config = configure_processor(state.config.processor) + except Exception as e: + logger.error(f"🚨 Failed to configure processor") + raise e + finally: + state.config_lock.release() # Initialize Search Models from Config try: - state.search_index_lock.acquire() + state.config_lock.acquire() state.SearchType = configure_search_types(state.config) state.search_models = configure_search(state.search_models, state.config.search_type) except Exception as e: - logger.error(f"🚨 Error configuring search models on app load: {e}") + logger.error(f"🚨 Failed to configure search models") + raise e finally: - state.search_index_lock.release() + state.config_lock.release() # Initialize Content from Config if state.search_models: try: - state.search_index_lock.acquire() + state.config_lock.acquire() state.content_index = configure_content( - state.content_index, state.config.content_type, state.search_models, args.regenerate + state.content_index, state.config.content_type, state.search_models, regenerate, search_type ) except Exception as e: - logger.error(f"🚨 Error configuring content index on app load: {e}") + logger.error(f"🚨 Failed to index content") + raise e finally: - state.search_index_lock.release() + state.config_lock.release() def configure_routes(app): @@ -95,7 +113,7 @@ if not state.demo: @schedule.repeat(schedule.every(61).minutes) def update_search_index(): try: - state.search_index_lock.acquire() + state.config_lock.acquire() state.content_index = configure_content( state.content_index, state.config.content_type, state.search_models, regenerate=False ) @@ -103,7 +121,7 @@ if not state.demo: except Exception as e: logger.error(f"🚨 Error updating content index via Scheduler: {e}") finally: - state.search_index_lock.release() + state.config_lock.release() def configure_search_types(config: FullConfig): @@ -118,10 +136,10 @@ def 
configure_search_types(config: FullConfig): return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) -def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]: +def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]: # Run Validation Checks if search_config is None: - logger.warning("🚨 No Search type is configured.") + logger.warning("🚨 No Search configuration available.") return None if search_models is None: search_models = SearchModels() @@ -147,7 +165,7 @@ def configure_content( ) -> Optional[ContentIndex]: # Run Validation Checks if content_config is None: - logger.warning("🚨 No Content type is configured.") + logger.warning("🚨 No Content configuration available.") return None if content_index is None: content_index = ContentIndex() @@ -242,9 +260,10 @@ def configure_content( return content_index -def configure_processor(processor_config: ProcessorConfig): +def configure_processor(processor_config: Optional[ProcessorConfig]): if not processor_config: - return + logger.warning("🚨 No Processor configuration available.") + return None processor = ProcessorConfigModel() diff --git a/src/khoj/main.py b/src/khoj/main.py index 9f7fe8df..6689805c 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -25,7 +25,7 @@ from rich.logging import RichHandler import schedule # Internal Packages -from khoj.configure import configure_routes, configure_server +from khoj.configure import configure_routes, initialize_server from khoj.utils import state from khoj.utils.cli import cli @@ -70,7 +70,7 @@ def run(): poll_task_scheduler() # Start Server - configure_server(args, required=False) + initialize_server(args.config, args.regenerate, required=False) configure_routes(app) start_server(app, host=args.host, port=args.port, socket=args.socket) else: @@ -93,7 +93,7 @@ def run(): tray.show() # Setup Server - configure_server(args, required=False) + initialize_server(args.config, args.regenerate, required=False) configure_routes(app) server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port)) diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py index dd797c31..91dbd6da 100644 --- a/src/khoj/processor/github/github_to_jsonl.py +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -13,9 +13,8 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry -from khoj.utils import state logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class GithubToJsonl(TextToJsonl): else: return - def process(self, previous_entries=None): + def process(self, previous_entries=[]): current_entries = [] for repo in self.config.repos: current_entries += self.process_repo(repo) @@ -98,10 +97,7 @@ class GithubToJsonl(TextToJsonl): jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) # Compress JSONL formatted Data - if self.config.compressed_jsonl.suffix == ".gz": - compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) - elif self.config.compressed_jsonl.suffix == ".jsonl": - dump_jsonl(jsonl_data, self.config.compressed_jsonl) + 
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) return entries_with_ids diff --git a/src/khoj/processor/jsonl/jsonl_to_jsonl.py b/src/khoj/processor/jsonl/jsonl_to_jsonl.py index f743d5d5..3c824545 100644 --- a/src/khoj/processor/jsonl/jsonl_to_jsonl.py +++ b/src/khoj/processor/jsonl/jsonl_to_jsonl.py @@ -7,7 +7,7 @@ from typing import List # Internal Packages from khoj.processor.text_to_jsonl import TextToJsonl from khoj.utils.helpers import get_absolute_path, timer -from khoj.utils.jsonl import load_jsonl, dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import load_jsonl, compress_jsonl_data from khoj.utils.rawconfig import Entry @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class JsonlToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries=None): + def process(self, previous_entries=[]): # Extract required fields from config input_jsonl_files, input_jsonl_filter, output_file = ( self.config.input_files, @@ -38,15 +38,9 @@ class JsonlToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, - previous_entries, - key="compiled", - logger=logger, - ) + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write entries to JSONL file", logger): # Process Each Entry from All Notes Files @@ -54,10 +48,7 @@ class JsonlToJsonl(TextToJsonl): jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries) # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) + compress_jsonl_data(jsonl_data, output_file) return entries_with_ids diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py index 21cbda72..b6acbfbb 100644 --- a/src/khoj/processor/markdown/markdown_to_jsonl.py +++ b/src/khoj/processor/markdown/markdown_to_jsonl.py @@ -10,7 +10,7 @@ from typing import List from khoj.processor.text_to_jsonl import TextToJsonl from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer from khoj.utils.constants import empty_escape_sequences -from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry, TextContentConfig @@ -23,7 +23,7 @@ class MarkdownToJsonl(TextToJsonl): self.config = config # Define Functions - def process(self, previous_entries=None): + def process(self, previous_entries=[]): # Extract required fields from config markdown_files, markdown_file_filter, output_file = ( self.config.input_files, @@ -51,12 +51,9 @@ class MarkdownToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger - ) + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write markdown entries to JSONL file", logger): # Process Each Entry from All Notes Files @@ -64,10 +61,7 @@ 
class MarkdownToJsonl(TextToJsonl): jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) + compress_jsonl_data(jsonl_data, output_file) return entries_with_ids diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py index 20a11cd7..489f0341 100644 --- a/src/khoj/processor/notion/notion_to_jsonl.py +++ b/src/khoj/processor/notion/notion_to_jsonl.py @@ -8,7 +8,7 @@ import requests from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry, NotionContentConfig from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry from enum import Enum @@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl): self.body_params = {"page_size": 100} - def process(self, previous_entries=None): + def process(self, previous_entries=[]): current_entries = [] # Get all pages @@ -240,12 +240,9 @@ class NotionToJsonl(TextToJsonl): def update_entries_with_ids(self, current_entries, previous_entries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger - ) + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write Notion entries to JSONL file", logger): # Process Each Entry from all Notion entries @@ -253,9 +250,6 @@ class NotionToJsonl(TextToJsonl): jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries) # Compress JSONL formatted Data - if self.config.compressed_jsonl.suffix == ".gz": - compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) - elif self.config.compressed_jsonl.suffix == ".jsonl": - dump_jsonl(jsonl_data, self.config.compressed_jsonl) + compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) return entries_with_ids diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py index 664427d9..b3bc06fd 100644 --- a/src/khoj/processor/org_mode/org_to_jsonl.py +++ b/src/khoj/processor/org_mode/org_to_jsonl.py @@ -8,7 +8,7 @@ from typing import Iterable, List from khoj.processor.org_mode import orgnode from khoj.processor.text_to_jsonl import TextToJsonl from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer -from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry, TextContentConfig from khoj.utils import state @@ -22,7 +22,7 @@ class OrgToJsonl(TextToJsonl): self.config = config # Define Functions - def process(self, previous_entries: List[Entry] = None): + def process(self, previous_entries: List[Entry] = []): # Extract required fields from config org_files, org_file_filter, output_file = ( self.config.input_files, @@ -51,9 +51,7 @@ class OrgToJsonl(TextToJsonl): current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) # Identify, mark and merge any new entries with previous entries - if not previous_entries: - entries_with_ids = 
list(enumerate(current_entries)) - else: + with timer("Identify new or updated entries", logger): entries_with_ids = TextToJsonl.mark_entries_for_update( current_entries, previous_entries, key="compiled", logger=logger ) @@ -64,10 +62,7 @@ class OrgToJsonl(TextToJsonl): jsonl_data = self.convert_org_entries_to_jsonl(entries) # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) + compress_jsonl_data(jsonl_data, output_file) return entries_with_ids @@ -125,9 +120,13 @@ class OrgToJsonl(TextToJsonl): # Ignore title notes i.e notes with just headings and empty body continue + todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else "" # Prepend filename as top heading to entry filename = Path(entry_to_file_map[parsed_entry]).stem - heading = f"* {filename}\n** {parsed_entry.heading}." if parsed_entry.heading else f"* {filename}." + if parsed_entry.heading: + heading = f"* {filename}\n** {todo_str}{parsed_entry.heading}." + else: + heading = f"* {filename}." compiled = heading if state.verbose > 2: diff --git a/src/khoj/processor/pdf/pdf_to_jsonl.py b/src/khoj/processor/pdf/pdf_to_jsonl.py index c5c395bc..f8a20692 100644 --- a/src/khoj/processor/pdf/pdf_to_jsonl.py +++ b/src/khoj/processor/pdf/pdf_to_jsonl.py @@ -10,7 +10,7 @@ from langchain.document_loaders import PyPDFLoader # Internal Packages from khoj.processor.text_to_jsonl import TextToJsonl from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer -from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data +from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) class PdfToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries=None): + def process(self, previous_entries=[]): # Extract required fields from config pdf_files, pdf_file_filter, output_file = ( self.config.input_files, @@ -45,12 +45,9 @@ class PdfToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger - ) + entries_with_ids = TextToJsonl.mark_entries_for_update( + current_entries, previous_entries, key="compiled", logger=logger + ) with timer("Write PDF entries to JSONL file", logger): # Process Each Entry from All Notes Files @@ -58,10 +55,7 @@ class PdfToJsonl(TextToJsonl): jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries) # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) + compress_jsonl_data(jsonl_data, output_file) return entries_with_ids diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index a4d01cf5..f92ab7b1 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -17,7 +17,7 @@ class TextToJsonl(ABC): self.config = config @abstractmethod - def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]: + def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]: ... 
@staticmethod @@ -78,16 +78,23 @@ class TextToJsonl(ABC): # All entries that exist in both current and previous sets are kept existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) + # load new entries in the order in which they are processed for a stable sort + new_entries = [ + (current_entry_hashes.index(entry_hash), hash_to_current_entries[entry_hash]) + for entry_hash in new_entry_hashes + ] + new_entries_sorted = sorted(new_entries, key=lambda e: e[0]) # Mark new entries with -1 id to flag for later embeddings generation - new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes] + new_entries_sorted = [(-1, entry[1]) for entry in new_entries_sorted] + # Set id of existing entries to their previous ids to reuse their existing encoded embeddings existing_entries = [ (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) for entry_hash in existing_entry_hashes ] - existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) - entries_with_ids = existing_entries_sorted + new_entries + + entries_with_ids = existing_entries_sorted + new_entries_sorted return entries_with_ids diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 50e8e1f2..834e8997 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -5,20 +5,20 @@ import time import yaml import logging import json -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Union # External Packages from fastapi import APIRouter, HTTPException, Header, Request from sentence_transformers import util # Internal Packages -from khoj.configure import configure_content, configure_processor, configure_search +from khoj.configure import configure_processor, configure_server from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.utils.config import TextSearchModel -from khoj.utils.helpers import log_telemetry, timer +from khoj.utils.helpers import timer from khoj.utils.rawconfig import ( ContentConfig, FullConfig, @@ -524,34 +524,26 @@ def update( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): + if not state.config: + error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}." 
+ logger.warning(error_msg) + raise HTTPException(status_code=500, detail=error_msg) try: - state.search_index_lock.acquire() - try: - if state.config and state.config.search_type: - state.search_models = configure_search(state.search_models, state.config.search_type) - if state.search_models: - state.content_index = configure_content( - state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t - ) - except Exception as e: - logger.error(e) - raise HTTPException(status_code=500, detail=str(e)) - finally: - state.search_index_lock.release() - except ValueError as e: - logger.error(e) - raise HTTPException(status_code=500, detail=str(e)) + configure_server(state.config, regenerate=force or False, search_type=t) + except Exception as e: + error_msg = f"🚨 Failed to update server via API: {e}" + logger.error(error_msg, exc_info=True) + raise HTTPException(status_code=500, detail=error_msg) else: - logger.info("📬 Search index updated via API") - - try: - if state.config and state.config.processor: - state.processor_config = configure_processor(state.config.processor) - except ValueError as e: - logger.error(e) - raise HTTPException(status_code=500, detail=str(e)) - else: - logger.info("📬 Processor reconfigured via API") + components = [] + if state.search_models: + components.append("Search models") + if state.content_index: + components.append("Content index") + if state.processor_config: + components.append("Conversation processor") + components_msg = ", ".join(components) + logger.info(f"📬 {components_msg} updated via API") update_telemetry_state( request=request, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index a77be6e1..09174186 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -58,43 +58,48 @@ def extract_entries(jsonl_file) -> List[Entry]: def compute_embeddings( - entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False + entries_with_ids: List[Tuple[int, Entry]], + bi_encoder: BaseEncoder, + embeddings_file: Path, + regenerate=False, + normalize=True, ): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" - new_entries = [] + new_embeddings = torch.tensor([], device=state.device) + existing_embeddings = torch.tensor([], device=state.device) + create_index_msg = "" # Load pre-computed embeddings from file if exists and update them if required if embeddings_file.exists() and not regenerate: corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device) logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}") - - # Encode any new entries in the corpus and update corpus embeddings - new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] - if new_entries: - logger.info(f"📩 Indexing {len(new_entries)} text entries.") - new_embeddings = bi_encoder.encode( - new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True - ) - existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] - if existing_entry_ids: - existing_embeddings = torch.index_select( - corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device) - ) - else: - existing_embeddings = torch.tensor([], device=state.device) - corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) - # Else compute the corpus embeddings from scratch else: - new_entries = [entry.compiled for _, entry in 
entries_with_ids] - logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.") - corpus_embeddings = bi_encoder.encode( + corpus_embeddings = torch.tensor([], device=state.device) + create_index_msg = " Creating index from scratch." + + # Encode any new entries in the corpus and update corpus embeddings + new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] + if new_entries: + logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}") + new_embeddings = bi_encoder.encode( new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True ) - # Save regenerated or updated embeddings to file - if new_entries: + # Extract existing embeddings from previous corpus embeddings + existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] + if existing_entry_ids: + existing_embeddings = torch.index_select( + corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device) + ) + + # Set corpus embeddings to merger of existing and new embeddings + corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) + if normalize: + # Normalize embeddings for faster lookup via dot product when querying corpus_embeddings = util.normalize_embeddings(corpus_embeddings) - torch.save(corpus_embeddings, embeddings_file) - logger.info(f"📩 Saved computed text embeddings to {embeddings_file}") + + # Save regenerated or updated embeddings to file + torch.save(corpus_embeddings, embeddings_file) + logger.info(f"📩 Saved computed text embeddings to {embeddings_file}") return corpus_embeddings @@ -173,13 +178,14 @@ def setup( bi_encoder: BaseEncoder, regenerate: bool, filters: List[BaseFilter] = [], + normalize: bool = True, ) -> TextContent: # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - previous_entries = ( - extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None - ) - entries_with_indices = text_to_jsonl(config).process(previous_entries or []) + previous_entries = [] + if config.compressed_jsonl.exists() and not regenerate: + previous_entries = extract_entries(config.compressed_jsonl) + entries_with_indices = text_to_jsonl(config).process(previous_entries) # Extract Updated Entries entries = extract_entries(config.compressed_jsonl) @@ -190,7 +196,7 @@ def setup( # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings( - entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate + entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize ) for filter in filters: diff --git a/src/khoj/utils/jsonl.py b/src/khoj/utils/jsonl.py index c9576810..ed779e79 100644 --- a/src/khoj/utils/jsonl.py +++ b/src/khoj/utils/jsonl.py @@ -20,7 +20,7 @@ def load_jsonl(input_path): # Open JSONL file if input_path.suffix == ".gz": jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8") - elif input_path.suffix == ".jsonl": + else: jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8") # Read JSONL file @@ -36,17 +36,6 @@ def load_jsonl(input_path): return data -def dump_jsonl(jsonl_data, output_path): - "Write List of JSON objects to JSON line file" - # Create output directory, if it doesn't exist - output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(output_path, "w", encoding="utf-8") 
as f: - f.write(jsonl_data) - - logger.debug(f"Wrote jsonl data to {output_path}") - - def compress_jsonl_data(jsonl_data, output_path): # Create output directory, if it doesn't exist output_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 89688e15..40b3daae 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -24,7 +24,7 @@ host: str = None port: int = None cli_args: List[str] = None query_cache = LRU() -search_index_lock = threading.Lock() +config_lock = threading.Lock() SearchType = utils_config.SearchType telemetry: List[Dict[str, str]] = [] previous_query: str = None diff --git a/tests/conftest.py b/tests/conftest.py index a92d33ca..07c5156f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,7 +90,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config: content_config.org = TextContentConfig( input_files=None, input_filter=["tests/data/org/*.org"], - compressed_jsonl=content_dir.joinpath("notes.jsonl"), + compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"), embeddings_file=content_dir.joinpath("note_embeddings.pt"), ) @@ -101,7 +101,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config: content_config.plugins = { "plugin1": TextContentConfig( - input_files=[content_dir.joinpath("notes.jsonl")], + input_files=[content_dir.joinpath("notes.jsonl.gz")], input_filter=None, compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"), embeddings_file=content_dir.joinpath("plugin_embeddings.pt"), @@ -142,7 +142,7 @@ def md_content_config(tmp_path_factory): content_config.markdown = TextContentConfig( input_files=None, input_filter=["tests/data/markdown/*.markdown"], - compressed_jsonl=content_dir.joinpath("markdown.jsonl"), + compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"), embeddings_file=content_dir.joinpath("markdown_embeddings.pt"), ) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index c18a4c42..1ae7e770 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -5,6 +5,7 @@ import os # External Packages import pytest +import torch from khoj.utils.config import SearchModels # Internal Packages @@ -17,7 +18,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl # Test # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup_with_missing_file_raises_error( +def test_text_search_setup_with_missing_file_raises_error( org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig ): # Arrange @@ -32,7 +33,7 @@ def test_asymmetric_setup_with_missing_file_raises_error( # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup_with_empty_file_raises_error( +def test_text_search_setup_with_empty_file_raises_error( org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig ): # Act @@ -42,7 +43,7 @@ def test_asymmetric_setup_with_empty_file_raises_error( # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels): +def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate notes embeddings during asymmetric setup notes_model = text_search.setup( @@ -55,7 +56,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_models: 
SearchMo # ---------------------------------------------------------------------------------------------------- -def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog): +def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog): # Arrange caplog.set_level(logging.INFO, logger="khoj") @@ -70,8 +71,8 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi final_logs = caplog.text # Assert - assert "📩 Saved computed text embeddings to" in initial_logs - assert "📩 Saved computed text embeddings to" not in final_logs + assert "Creating index from scratch." in initial_logs + assert "Creating index from scratch." not in final_logs # ---------------------------------------------------------------------------------------------------- @@ -121,7 +122,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): +def test_regenerate_index_with_new_entry( + content_config: ContentConfig, search_models: SearchModels, new_org_file: Path +): # Arrange initial_notes_model = text_search.setup( OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True @@ -135,25 +138,20 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM with open(new_org_file, "w") as f: f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") + # Act # regenerate notes jsonl, model embeddings and model to include entry from new file regenerated_notes_model = text_search.setup( OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True ) - # Act - # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup( - OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False - ) - # Assert assert len(regenerated_notes_model.entries) == 11 assert len(regenerated_notes_model.corpus_embeddings) == 11 - # Assert - # verify new entry loaded from updated embeddings, entries - assert len(initial_notes_model.entries) == 11 - assert len(initial_notes_model.corpus_embeddings) == 11 + # verify new entry appended to index, without disrupting order or content of existing entries + error_details = compare_index(initial_notes_model, regenerated_notes_model) + if error_details: + pytest.fail(error_details, False) # Cleanup # reset input_files in config to empty list @@ -161,30 +159,101 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM # ---------------------------------------------------------------------------------------------------- -def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): +def test_update_index_with_duplicate_entries_in_stable_order( + org_config_with_only_new_file: TextContentConfig, search_models: SearchModels +): # Arrange - initial_notes_model = text_search.setup( - OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) + + # Insert org-mode entries with same compiled form into new org file + new_entry = "* TODO A Chihuahua 
doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" + with open(new_file_to_index, "w") as f: + f.write(f"{new_entry}{new_entry}") + + # Act + # load embeddings, entries, notes model after adding new org-mode file + initial_index = text_search.setup( + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True ) - assert len(initial_notes_model.entries) == 10 - assert len(initial_notes_model.corpus_embeddings) == 10 + # update embeddings, entries, notes model after adding new org-mode file + updated_index = text_search.setup( + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False + ) + + # Assert + # verify only 1 entry added even if there are multiple duplicate entries + assert len(initial_index.entries) == len(updated_index.entries) == 1 + assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1 + + # verify the same entry is added even when there are multiple duplicate entries + error_details = compare_index(initial_index, updated_index) + if error_details: + pytest.fail(error_details) + + +# ---------------------------------------------------------------------------------------------------- +def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels): + # Arrange + new_file_to_index = Path(org_config_with_only_new_file.input_files[0]) + + # Insert org-mode entries with same compiled form into new org file + new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" + with open(new_file_to_index, "w") as f: + f.write(f"{new_entry}{new_entry} -- Tatooine") + + # load embeddings, entries, notes model after adding new org file with 2 entries + initial_index = text_search.setup( + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True + ) + + # update embeddings, entries, notes model after removing an entry from the org file + with open(new_file_to_index, "w") as f: + f.write(f"{new_entry}") + + # Act + updated_index = text_search.setup( + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False + ) + + # Assert + # verify only 1 entry added even if there are multiple duplicate entries + assert len(initial_index.entries) == len(updated_index.entries) + 1 + assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1 + + # verify the same entry is added even when there are multiple duplicate entries + error_details = compare_index(updated_index, initial_index) + if error_details: + pytest.fail(error_details) + + +# ---------------------------------------------------------------------------------------------------- +def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): + # Arrange + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False + ) # append org-mode entry to first org input file in config with open(new_org_file, "w") as f: - f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") + new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" + f.write(new_entry) # Act # update embeddings, entries with the newly added note content_config.org.input_files = 
[f"{new_org_file}"] - initial_notes_model = text_search.setup( - OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + final_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False ) # Assert - # verify new entry added in updated embeddings, entries - assert len(initial_notes_model.entries) == 11 - assert len(initial_notes_model.corpus_embeddings) == 11 + assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1 + assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1 + + # verify new entry appended to index, without disrupting order or content of existing entries + error_details = compare_index(initial_notes_model, final_notes_model) + if error_details: + pytest.fail(error_details, False) # Cleanup # reset input_files in config to empty list @@ -202,3 +271,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S # Assert assert len(github_model.entries) > 1 + + +def compare_index(initial_notes_model, final_notes_model): + mismatched_entries, mismatched_embeddings = [], [] + for index in range(len(initial_notes_model.entries)): + if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json(): + mismatched_entries.append(index) + + # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings + for index in range(len(initial_notes_model.corpus_embeddings)): + if not torch.equal(final_notes_model.corpus_embeddings[index], initial_notes_model.corpus_embeddings[index]): + mismatched_embeddings.append(index) + + error_details = "" + if mismatched_entries: + mismatched_entries_str = ",".join(map(str, mismatched_entries)) + error_details += f"Entries at {mismatched_entries_str} not equal\n" + if mismatched_embeddings: + mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings)) + error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n" + + return error_details