Merge pull request #325 from khoj-ai/stablize-simplify-content-indexing

## Stabilize and Simplify Content Indexing

### Major Updates
- 9bcca43 Unify logic to update entries when indexing from scratch or incrementally
- 89c7819 Unify logic to update embeddings when indexing from scratch or incrementally
- 6a0297c Stable sort new entries when marking entries for update
- 58d86d7 Unify logic to configure server from API or on server start
- Create tests to ensure old entries, embeddings in index are unaffected on adding new entries
  - Refer: 1482fd4, 7669b85, 88d1a29 
  - ad41ef3 Make normalization of embeddings configurable to test this in c73feeb

### Minor Updates
- 1673bb5 Add todo state to compiled form of each entry
- 6e70b91 Remove unused `dump_jsonl` helper method 
- 7ad9603 Improve naming of lock
- b02323a Improve naming text search test methods

Resolves #190
This commit is contained in:
Debanjum 2023-07-17 14:51:10 -07:00 committed by GitHub
commit d00c5da8b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 280 additions and 208 deletions

View file

@ -37,45 +37,63 @@ from khoj.search_filter.file_filter import FileFilter
logger = logging.getLogger(__name__)
def configure_server(args, required=False):
if args.config is None:
if required:
def initialize_server(
config: Optional[FullConfig], regenerate: bool, type: Optional[SearchType] = None, required=False
):
if config is None and required:
logger.error(
f"Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
)
sys.exit(1)
else:
elif config is None:
logger.warning(
f"Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
)
return
else:
state.config = args.config
return None
try:
configure_server(config, regenerate, type)
except Exception as e:
logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None):
# Update Config
state.config = config
# Initialize Processor from Config
state.processor_config = configure_processor(args.config.processor)
try:
state.config_lock.acquire()
state.processor_config = configure_processor(state.config.processor)
except Exception as e:
logger.error(f"🚨 Failed to configure processor")
raise e
finally:
state.config_lock.release()
# Initialize Search Models from Config
try:
state.search_index_lock.acquire()
state.config_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
except Exception as e:
logger.error(f"🚨 Error configuring search models on app load: {e}")
logger.error(f"🚨 Failed to configure search models")
raise e
finally:
state.search_index_lock.release()
state.config_lock.release()
# Initialize Content from Config
if state.search_models:
try:
state.search_index_lock.acquire()
state.config_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, args.regenerate
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
)
except Exception as e:
logger.error(f"🚨 Error configuring content index on app load: {e}")
logger.error(f"🚨 Failed to index content")
raise e
finally:
state.search_index_lock.release()
state.config_lock.release()
def configure_routes(app):
@ -95,7 +113,7 @@ if not state.demo:
@schedule.repeat(schedule.every(61).minutes)
def update_search_index():
try:
state.search_index_lock.acquire()
state.config_lock.acquire()
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=False
)
@ -103,7 +121,7 @@ if not state.demo:
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
finally:
state.search_index_lock.release()
state.config_lock.release()
def configure_search_types(config: FullConfig):
@ -118,10 +136,10 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
# Run Validation Checks
if search_config is None:
logger.warning("🚨 No Search type is configured.")
logger.warning("🚨 No Search configuration available.")
return None
if search_models is None:
search_models = SearchModels()
@ -147,7 +165,7 @@ def configure_content(
) -> Optional[ContentIndex]:
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content type is configured.")
logger.warning("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
@ -242,9 +260,10 @@ def configure_content(
return content_index
def configure_processor(processor_config: ProcessorConfig):
def configure_processor(processor_config: Optional[ProcessorConfig]):
if not processor_config:
return
logger.warning("🚨 No Processor configuration available.")
return None
processor = ProcessorConfigModel()

View file

@ -25,7 +25,7 @@ from rich.logging import RichHandler
import schedule
# Internal Packages
from khoj.configure import configure_routes, configure_server
from khoj.configure import configure_routes, initialize_server
from khoj.utils import state
from khoj.utils.cli import cli
@ -70,7 +70,7 @@ def run():
poll_task_scheduler()
# Start Server
configure_server(args, required=False)
initialize_server(args.config, args.regenerate, required=False)
configure_routes(app)
start_server(app, host=args.host, port=args.port, socket=args.socket)
else:
@ -93,7 +93,7 @@ def run():
tray.show()
# Setup Server
configure_server(args, required=False)
initialize_server(args.config, args.regenerate, required=False)
configure_routes(app)
server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port))

View file

@ -13,9 +13,8 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
from khoj.utils import state
logger = logging.getLogger(__name__)
@ -38,7 +37,7 @@ class GithubToJsonl(TextToJsonl):
else:
return
def process(self, previous_entries=None):
def process(self, previous_entries=[]):
current_entries = []
for repo in self.config.repos:
current_entries += self.process_repo(repo)
@ -98,10 +97,7 @@ class GithubToJsonl(TextToJsonl):
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if self.config.compressed_jsonl.suffix == ".gz":
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
elif self.config.compressed_jsonl.suffix == ".jsonl":
dump_jsonl(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids

View file

@ -7,7 +7,7 @@ from typing import List
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.jsonl import load_jsonl, dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
from khoj.utils.rawconfig import Entry
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class JsonlToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
def process(self, previous_entries=[]):
# Extract required fields from config
input_jsonl_files, input_jsonl_filter, output_file = (
self.config.input_files,
@ -38,14 +38,8 @@ class JsonlToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries,
previous_entries,
key="compiled",
logger=logger,
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write entries to JSONL file", logger):
@ -54,10 +48,7 @@ class JsonlToJsonl(TextToJsonl):
jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View file

@ -10,7 +10,7 @@ from typing import List
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
@ -23,7 +23,7 @@ class MarkdownToJsonl(TextToJsonl):
self.config = config
# Define Functions
def process(self, previous_entries=None):
def process(self, previous_entries=[]):
# Extract required fields from config
markdown_files, markdown_file_filter, output_file = (
self.config.input_files,
@ -51,9 +51,6 @@ class MarkdownToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
@ -64,10 +61,7 @@ class MarkdownToJsonl(TextToJsonl):
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View file

@ -8,7 +8,7 @@ import requests
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
from enum import Enum
@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl):
self.body_params = {"page_size": 100}
def process(self, previous_entries=None):
def process(self, previous_entries=[]):
current_entries = []
# Get all pages
@ -240,9 +240,6 @@ class NotionToJsonl(TextToJsonl):
def update_entries_with_ids(self, current_entries, previous_entries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
@ -253,9 +250,6 @@ class NotionToJsonl(TextToJsonl):
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if self.config.compressed_jsonl.suffix == ".gz":
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
elif self.config.compressed_jsonl.suffix == ".jsonl":
dump_jsonl(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids

View file

@ -8,7 +8,7 @@ from typing import Iterable, List
from khoj.processor.org_mode import orgnode
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils import state
@ -22,7 +22,7 @@ class OrgToJsonl(TextToJsonl):
self.config = config
# Define Functions
def process(self, previous_entries: List[Entry] = None):
def process(self, previous_entries: List[Entry] = []):
# Extract required fields from config
org_files, org_file_filter, output_file = (
self.config.input_files,
@ -51,9 +51,7 @@ class OrgToJsonl(TextToJsonl):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
@ -64,10 +62,7 @@ class OrgToJsonl(TextToJsonl):
jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids
@ -125,9 +120,13 @@ class OrgToJsonl(TextToJsonl):
# Ignore title notes i.e notes with just headings and empty body
continue
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
# Prepend filename as top heading to entry
filename = Path(entry_to_file_map[parsed_entry]).stem
heading = f"* {filename}\n** {parsed_entry.heading}." if parsed_entry.heading else f"* {filename}."
if parsed_entry.heading:
heading = f"* {filename}\n** {todo_str}{parsed_entry.heading}."
else:
heading = f"* {filename}."
compiled = heading
if state.verbose > 2:

View file

@ -10,7 +10,7 @@ from langchain.document_loaders import PyPDFLoader
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
class PdfToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=None):
def process(self, previous_entries=[]):
# Extract required fields from config
pdf_files, pdf_file_filter, output_file = (
self.config.input_files,
@ -45,9 +45,6 @@ class PdfToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
@ -58,10 +55,7 @@ class PdfToJsonl(TextToJsonl):
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View file

@ -17,7 +17,7 @@ class TextToJsonl(ABC):
self.config = config
@abstractmethod
def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]:
...
@staticmethod
@ -78,16 +78,23 @@ class TextToJsonl(ABC):
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# load new entries in the order in which they are processed for a stable sort
new_entries = [
(current_entry_hashes.index(entry_hash), hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
new_entries_sorted = sorted(new_entries, key=lambda e: e[0])
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
new_entries_sorted = [(-1, entry[1]) for entry in new_entries_sorted]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
entries_with_ids = existing_entries_sorted + new_entries_sorted
return entries_with_ids

View file

@ -5,20 +5,20 @@ import time
import yaml
import logging
import json
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Union
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_content, configure_processor, configure_search
from khoj.configure import configure_processor, configure_server
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import log_telemetry, timer
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
@ -524,34 +524,26 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
state.search_index_lock.acquire()
try:
if state.config and state.config.search_type:
state.search_models = configure_search(state.search_models, state.config.search_type)
if state.search_models:
state.content_index = configure_content(
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
)
configure_server(state.config, regenerate=force or False, search_type=t)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
finally:
state.search_index_lock.release()
except ValueError as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)
raise HTTPException(status_code=500, detail=error_msg)
else:
logger.info("📬 Search index updated via API")
try:
if state.config and state.config.processor:
state.processor_config = configure_processor(state.config.processor)
except ValueError as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
else:
logger.info("📬 Processor reconfigured via API")
components = []
if state.search_models:
components.append("Search models")
if state.content_index:
components.append("Content index")
if state.processor_config:
components.append("Conversation processor")
components_msg = ", ".join(components)
logger.info(f"📬 {components_msg} updated via API")
update_telemetry_state(
request=request,

View file

@ -58,41 +58,46 @@ def extract_entries(jsonl_file) -> List[Entry]:
def compute_embeddings(
entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
entries_with_ids: List[Tuple[int, Entry]],
bi_encoder: BaseEncoder,
embeddings_file: Path,
regenerate=False,
normalize=True,
):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
new_embeddings = torch.tensor([], device=state.device)
existing_embeddings = torch.tensor([], device=state.device)
create_index_msg = ""
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
else:
corpus_embeddings = torch.tensor([], device=state.device)
create_index_msg = " Creating index from scratch."
# Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries:
logger.info(f"📩 Indexing {len(new_entries)} text entries.")
logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}")
new_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
# Extract existing embeddings from previous corpus embeddings
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
else:
existing_embeddings = torch.tensor([], device=state.device)
# Set corpus embeddings to merger of existing and new embeddings
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch
else:
new_entries = [entry.compiled for _, entry in entries_with_ids]
logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.")
corpus_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
if normalize:
# Normalize embeddings for faster lookup via dot product when querying
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
# Save regenerated or updated embeddings to file
if new_entries:
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, embeddings_file)
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
@ -173,13 +178,14 @@ def setup(
bi_encoder: BaseEncoder,
regenerate: bool,
filters: List[BaseFilter] = [],
normalize: bool = True,
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = (
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
)
entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
previous_entries = []
if config.compressed_jsonl.exists() and not regenerate:
previous_entries = extract_entries(config.compressed_jsonl)
entries_with_indices = text_to_jsonl(config).process(previous_entries)
# Extract Updated Entries
entries = extract_entries(config.compressed_jsonl)
@ -190,7 +196,7 @@ def setup(
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
)
for filter in filters:

View file

@ -20,7 +20,7 @@ def load_jsonl(input_path):
# Open JSONL file
if input_path.suffix == ".gz":
jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
elif input_path.suffix == ".jsonl":
else:
jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
# Read JSONL file
@ -36,17 +36,6 @@ def load_jsonl(input_path):
return data
def dump_jsonl(jsonl_data, output_path):
"Write List of JSON objects to JSON line file"
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write(jsonl_data)
logger.debug(f"Wrote jsonl data to {output_path}")
def compress_jsonl_data(jsonl_data, output_path):
# Create output directory, if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)

View file

@ -24,7 +24,7 @@ host: str = None
port: int = None
cli_args: List[str] = None
query_cache = LRU()
search_index_lock = threading.Lock()
config_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []
previous_query: str = None

View file

@ -90,7 +90,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
content_config.org = TextContentConfig(
input_files=None,
input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl"),
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
)
@ -101,7 +101,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
content_config.plugins = {
"plugin1": TextContentConfig(
input_files=[content_dir.joinpath("notes.jsonl")],
input_files=[content_dir.joinpath("notes.jsonl.gz")],
input_filter=None,
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
@ -142,7 +142,7 @@ def md_content_config(tmp_path_factory):
content_config.markdown = TextContentConfig(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
compressed_jsonl=content_dir.joinpath("markdown.jsonl"),
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
)

View file

@ -5,6 +5,7 @@ import os
# External Packages
import pytest
import torch
from khoj.utils.config import SearchModels
# Internal Packages
@ -17,7 +18,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
# Test
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_missing_file_raises_error(
def test_text_search_setup_with_missing_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Arrange
@ -32,7 +33,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup_with_empty_file_raises_error(
def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
):
# Act
@ -42,7 +43,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(
@ -55,7 +56,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchMo
# ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
@ -70,8 +71,8 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
final_logs = caplog.text
# Assert
assert "📩 Saved computed text embeddings to" in initial_logs
assert "📩 Saved computed text embeddings to" not in final_logs
assert "Creating index from scratch." in initial_logs
assert "Creating index from scratch." not in final_logs
# ----------------------------------------------------------------------------------------------------
@ -121,7 +122,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
):
# Arrange
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
@ -135,25 +138,20 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# Act
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
assert len(regenerated_notes_model.entries) == 11
assert len(regenerated_notes_model.corpus_embeddings) == 11
# Assert
# verify new entry loaded from updated embeddings, entries
assert len(initial_notes_model.entries) == 11
assert len(initial_notes_model.corpus_embeddings) == 11
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, regenerated_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
@ -161,30 +159,101 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
):
# Arrange
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry}")
# Act
# load embeddings, entries, notes model after adding new org-mode file
initial_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
# update embeddings, entries, notes model after adding new org-mode file
updated_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) == 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(initial_index, updated_index)
if error_details:
pytest.fail(error_details)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
# Insert org-mode entries with same compiled form into new org file
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
# Act
updated_index = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) + 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(updated_index, initial_index)
if error_details:
pytest.fail(error_details)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
final_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
)
# Assert
# verify new entry added in updated embeddings, entries
assert len(initial_notes_model.entries) == 11
assert len(initial_notes_model.corpus_embeddings) == 11
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, final_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
@ -202,3 +271,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S
# Assert
assert len(github_model.entries) > 1
def compare_index(initial_notes_model, final_notes_model):
mismatched_entries, mismatched_embeddings = [], []
for index in range(len(initial_notes_model.entries)):
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
mismatched_entries.append(index)
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
for index in range(len(initial_notes_model.corpus_embeddings)):
if not torch.equal(final_notes_model.corpus_embeddings[index], initial_notes_model.corpus_embeddings[index]):
mismatched_embeddings.append(index)
error_details = ""
if mismatched_entries:
mismatched_entries_str = ",".join(map(str, mismatched_entries))
error_details += f"Entries at {mismatched_entries_str} not equal\n"
if mismatched_embeddings:
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
return error_details