mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Merge pull request #325 from khoj-ai/stablize-simplify-content-indexing
## Stabilize and Simplify Content Indexing ### Major Updates -9bcca43
Unify logic to update entries when indexing from scratch or incrementally -89c7819
Unify logic to update embeddings when indexing from scratch or incrementally -6a0297c
Stable sort new entries when marking entries for update -58d86d7
Unify logic to configure server from API or on server start - Create tests to ensure old entries, embeddings in index are unaffected on adding new entries - Refer:1482fd4
,7669b85
,88d1a29
-ad41ef3
Make normalization of embeddings configurable to test this inc73feeb
### Minor Updates -1673bb5
Add todo state to compiled form of each entry -6e70b91
Remove unused `dump_jsonl` helper method -7ad9603
Improve naming of lock -b02323a
Improve naming text search test methods Resolves #190
This commit is contained in:
commit
d00c5da8b7
15 changed files with 280 additions and 208 deletions
|
@ -37,45 +37,63 @@ from khoj.search_filter.file_filter import FileFilter
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure_server(args, required=False):
|
||||
if args.config is None:
|
||||
if required:
|
||||
def initialize_server(
|
||||
config: Optional[FullConfig], regenerate: bool, type: Optional[SearchType] = None, required=False
|
||||
):
|
||||
if config is None and required:
|
||||
logger.error(
|
||||
f"Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
|
||||
f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
elif config is None:
|
||||
logger.warning(
|
||||
f"Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||
f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||
)
|
||||
return
|
||||
else:
|
||||
state.config = args.config
|
||||
return None
|
||||
|
||||
try:
|
||||
configure_server(config, regenerate, type)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
|
||||
|
||||
|
||||
def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None):
|
||||
# Update Config
|
||||
state.config = config
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor)
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure processor")
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
# Initialize Search Models from Config
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.config_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring search models on app load: {e}")
|
||||
logger.error(f"🚨 Failed to configure search models")
|
||||
raise e
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
state.config_lock.release()
|
||||
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.config_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, args.regenerate
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error configuring content index on app load: {e}")
|
||||
logger.error(f"🚨 Failed to index content")
|
||||
raise e
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def configure_routes(app):
|
||||
|
@ -95,7 +113,7 @@ if not state.demo:
|
|||
@schedule.repeat(schedule.every(61).minutes)
|
||||
def update_search_index():
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
state.config_lock.acquire()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=False
|
||||
)
|
||||
|
@ -103,7 +121,7 @@ if not state.demo:
|
|||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def configure_search_types(config: FullConfig):
|
||||
|
@ -118,10 +136,10 @@ def configure_search_types(config: FullConfig):
|
|||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
|
||||
|
||||
def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
|
||||
def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
|
||||
# Run Validation Checks
|
||||
if search_config is None:
|
||||
logger.warning("🚨 No Search type is configured.")
|
||||
logger.warning("🚨 No Search configuration available.")
|
||||
return None
|
||||
if search_models is None:
|
||||
search_models = SearchModels()
|
||||
|
@ -147,7 +165,7 @@ def configure_content(
|
|||
) -> Optional[ContentIndex]:
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content type is configured.")
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
@ -242,9 +260,10 @@ def configure_content(
|
|||
return content_index
|
||||
|
||||
|
||||
def configure_processor(processor_config: ProcessorConfig):
|
||||
def configure_processor(processor_config: Optional[ProcessorConfig]):
|
||||
if not processor_config:
|
||||
return
|
||||
logger.warning("🚨 No Processor configuration available.")
|
||||
return None
|
||||
|
||||
processor = ProcessorConfigModel()
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from rich.logging import RichHandler
|
|||
import schedule
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, configure_server
|
||||
from khoj.configure import configure_routes, initialize_server
|
||||
from khoj.utils import state
|
||||
from khoj.utils.cli import cli
|
||||
|
||||
|
@ -70,7 +70,7 @@ def run():
|
|||
poll_task_scheduler()
|
||||
|
||||
# Start Server
|
||||
configure_server(args, required=False)
|
||||
initialize_server(args.config, args.regenerate, required=False)
|
||||
configure_routes(app)
|
||||
start_server(app, host=args.host, port=args.port, socket=args.socket)
|
||||
else:
|
||||
|
@ -93,7 +93,7 @@ def run():
|
|||
tray.show()
|
||||
|
||||
# Setup Server
|
||||
configure_server(args, required=False)
|
||||
initialize_server(args.config, args.regenerate, required=False)
|
||||
configure_routes(app)
|
||||
server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port))
|
||||
|
||||
|
|
|
@ -13,9 +13,8 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
|||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -38,7 +37,7 @@ class GithubToJsonl(TextToJsonl):
|
|||
else:
|
||||
return
|
||||
|
||||
def process(self, previous_entries=None):
|
||||
def process(self, previous_entries=[]):
|
||||
current_entries = []
|
||||
for repo in self.config.repos:
|
||||
current_entries += self.process_repo(repo)
|
||||
|
@ -98,10 +97,7 @@ class GithubToJsonl(TextToJsonl):
|
|||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if self.config.compressed_jsonl.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
||||
elif self.config.compressed_jsonl.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, self.config.compressed_jsonl)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import List
|
|||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, timer
|
||||
from khoj.utils.jsonl import load_jsonl, dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class JsonlToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=None):
|
||||
def process(self, previous_entries=[]):
|
||||
# Extract required fields from config
|
||||
input_jsonl_files, input_jsonl_filter, output_file = (
|
||||
self.config.input_files,
|
||||
|
@ -38,14 +38,8 @@ class JsonlToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries,
|
||||
previous_entries,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
)
|
||||
|
||||
with timer("Write entries to JSONL file", logger):
|
||||
|
@ -54,10 +48,7 @@ class JsonlToJsonl(TextToJsonl):
|
|||
jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import List
|
|||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
|
||||
|
||||
|
@ -23,7 +23,7 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
self.config = config
|
||||
|
||||
# Define Functions
|
||||
def process(self, previous_entries=None):
|
||||
def process(self, previous_entries=[]):
|
||||
# Extract required fields from config
|
||||
markdown_files, markdown_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
|
@ -51,9 +51,6 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
)
|
||||
|
@ -64,10 +61,7 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import requests
|
|||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
from enum import Enum
|
||||
|
@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl):
|
|||
|
||||
self.body_params = {"page_size": 100}
|
||||
|
||||
def process(self, previous_entries=None):
|
||||
def process(self, previous_entries=[]):
|
||||
current_entries = []
|
||||
|
||||
# Get all pages
|
||||
|
@ -240,9 +240,6 @@ class NotionToJsonl(TextToJsonl):
|
|||
def update_entries_with_ids(self, current_entries, previous_entries):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
)
|
||||
|
@ -253,9 +250,6 @@ class NotionToJsonl(TextToJsonl):
|
|||
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if self.config.compressed_jsonl.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
||||
elif self.config.compressed_jsonl.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, self.config.compressed_jsonl)
|
||||
|
||||
return entries_with_ids
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Iterable, List
|
|||
from khoj.processor.org_mode import orgnode
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from khoj.utils import state
|
||||
|
||||
|
@ -22,7 +22,7 @@ class OrgToJsonl(TextToJsonl):
|
|||
self.config = config
|
||||
|
||||
# Define Functions
|
||||
def process(self, previous_entries: List[Entry] = None):
|
||||
def process(self, previous_entries: List[Entry] = []):
|
||||
# Extract required fields from config
|
||||
org_files, org_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
|
@ -51,9 +51,7 @@ class OrgToJsonl(TextToJsonl):
|
|||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
)
|
||||
|
@ -64,10 +62,7 @@ class OrgToJsonl(TextToJsonl):
|
|||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
@ -125,9 +120,13 @@ class OrgToJsonl(TextToJsonl):
|
|||
# Ignore title notes i.e notes with just headings and empty body
|
||||
continue
|
||||
|
||||
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
|
||||
# Prepend filename as top heading to entry
|
||||
filename = Path(entry_to_file_map[parsed_entry]).stem
|
||||
heading = f"* {filename}\n** {parsed_entry.heading}." if parsed_entry.heading else f"* {filename}."
|
||||
if parsed_entry.heading:
|
||||
heading = f"* {filename}\n** {todo_str}{parsed_entry.heading}."
|
||||
else:
|
||||
heading = f"* {filename}."
|
||||
|
||||
compiled = heading
|
||||
if state.verbose > 2:
|
||||
|
|
|
@ -10,7 +10,7 @@ from langchain.document_loaders import PyPDFLoader
|
|||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class PdfToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=None):
|
||||
def process(self, previous_entries=[]):
|
||||
# Extract required fields from config
|
||||
pdf_files, pdf_file_filter, output_file = (
|
||||
self.config.input_files,
|
||||
|
@ -45,9 +45,6 @@ class PdfToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
)
|
||||
|
@ -58,10 +55,7 @@ class PdfToJsonl(TextToJsonl):
|
|||
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ class TextToJsonl(ABC):
|
|||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
|
||||
def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
|
@ -78,16 +78,23 @@ class TextToJsonl(ABC):
|
|||
# All entries that exist in both current and previous sets are kept
|
||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||
|
||||
# load new entries in the order in which they are processed for a stable sort
|
||||
new_entries = [
|
||||
(current_entry_hashes.index(entry_hash), hash_to_current_entries[entry_hash])
|
||||
for entry_hash in new_entry_hashes
|
||||
]
|
||||
new_entries_sorted = sorted(new_entries, key=lambda e: e[0])
|
||||
# Mark new entries with -1 id to flag for later embeddings generation
|
||||
new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
|
||||
new_entries_sorted = [(-1, entry[1]) for entry in new_entries_sorted]
|
||||
|
||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||
existing_entries = [
|
||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||
for entry_hash in existing_entry_hashes
|
||||
]
|
||||
|
||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||
entries_with_ids = existing_entries_sorted + new_entries
|
||||
|
||||
entries_with_ids = existing_entries_sorted + new_entries_sorted
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
|
|
|
@ -5,20 +5,20 @@ import time
|
|||
import yaml
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterable, List, Optional, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_content, configure_processor, configure_search
|
||||
from khoj.configure import configure_processor, configure_server
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.helpers import log_telemetry, timer
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
FullConfig,
|
||||
|
@ -524,34 +524,26 @@ def update(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
if not state.config:
|
||||
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||
logger.warning(error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
try:
|
||||
state.search_index_lock.acquire()
|
||||
try:
|
||||
if state.config and state.config.search_type:
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
if state.search_models:
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
|
||||
)
|
||||
configure_server(state.config, regenerate=force or False, search_type=t)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
state.search_index_lock.release()
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
error_msg = f"🚨 Failed to update server via API: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
else:
|
||||
logger.info("📬 Search index updated via API")
|
||||
|
||||
try:
|
||||
if state.config and state.config.processor:
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
else:
|
||||
logger.info("📬 Processor reconfigured via API")
|
||||
components = []
|
||||
if state.search_models:
|
||||
components.append("Search models")
|
||||
if state.content_index:
|
||||
components.append("Content index")
|
||||
if state.processor_config:
|
||||
components.append("Conversation processor")
|
||||
components_msg = ", ".join(components)
|
||||
logger.info(f"📬 {components_msg} updated via API")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
|
|
@ -58,41 +58,46 @@ def extract_entries(jsonl_file) -> List[Entry]:
|
|||
|
||||
|
||||
def compute_embeddings(
|
||||
entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
|
||||
entries_with_ids: List[Tuple[int, Entry]],
|
||||
bi_encoder: BaseEncoder,
|
||||
embeddings_file: Path,
|
||||
regenerate=False,
|
||||
normalize=True,
|
||||
):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
new_entries = []
|
||||
new_embeddings = torch.tensor([], device=state.device)
|
||||
existing_embeddings = torch.tensor([], device=state.device)
|
||||
create_index_msg = ""
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
|
||||
else:
|
||||
corpus_embeddings = torch.tensor([], device=state.device)
|
||||
create_index_msg = " Creating index from scratch."
|
||||
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
|
||||
if new_entries:
|
||||
logger.info(f"📩 Indexing {len(new_entries)} text entries.")
|
||||
logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}")
|
||||
new_embeddings = bi_encoder.encode(
|
||||
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
||||
)
|
||||
|
||||
# Extract existing embeddings from previous corpus embeddings
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||
if existing_entry_ids:
|
||||
existing_embeddings = torch.index_select(
|
||||
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
|
||||
)
|
||||
else:
|
||||
existing_embeddings = torch.tensor([], device=state.device)
|
||||
|
||||
# Set corpus embeddings to merger of existing and new embeddings
|
||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||
# Else compute the corpus embeddings from scratch
|
||||
else:
|
||||
new_entries = [entry.compiled for _, entry in entries_with_ids]
|
||||
logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.")
|
||||
corpus_embeddings = bi_encoder.encode(
|
||||
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
|
||||
)
|
||||
if normalize:
|
||||
# Normalize embeddings for faster lookup via dot product when querying
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
|
||||
# Save regenerated or updated embeddings to file
|
||||
if new_entries:
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
|
||||
|
||||
|
@ -173,13 +178,14 @@ def setup(
|
|||
bi_encoder: BaseEncoder,
|
||||
regenerate: bool,
|
||||
filters: List[BaseFilter] = [],
|
||||
normalize: bool = True,
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
previous_entries = (
|
||||
extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
||||
)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
|
||||
previous_entries = []
|
||||
if config.compressed_jsonl.exists() and not regenerate:
|
||||
previous_entries = extract_entries(config.compressed_jsonl)
|
||||
entries_with_indices = text_to_jsonl(config).process(previous_entries)
|
||||
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
@ -190,7 +196,7 @@ def setup(
|
|||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(
|
||||
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
|
||||
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
|
||||
)
|
||||
|
||||
for filter in filters:
|
||||
|
|
|
@ -20,7 +20,7 @@ def load_jsonl(input_path):
|
|||
# Open JSONL file
|
||||
if input_path.suffix == ".gz":
|
||||
jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
|
||||
elif input_path.suffix == ".jsonl":
|
||||
else:
|
||||
jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
|
||||
|
||||
# Read JSONL file
|
||||
|
@ -36,17 +36,6 @@ def load_jsonl(input_path):
|
|||
return data
|
||||
|
||||
|
||||
def dump_jsonl(jsonl_data, output_path):
|
||||
"Write List of JSON objects to JSON line file"
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(jsonl_data)
|
||||
|
||||
logger.debug(f"Wrote jsonl data to {output_path}")
|
||||
|
||||
|
||||
def compress_jsonl_data(jsonl_data, output_path):
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -24,7 +24,7 @@ host: str = None
|
|||
port: int = None
|
||||
cli_args: List[str] = None
|
||||
query_cache = LRU()
|
||||
search_index_lock = threading.Lock()
|
||||
config_lock = threading.Lock()
|
||||
SearchType = utils_config.SearchType
|
||||
telemetry: List[Dict[str, str]] = []
|
||||
previous_query: str = None
|
||||
|
|
|
@ -90,7 +90,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
content_config.org = TextContentConfig(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
compressed_jsonl=content_dir.joinpath("notes.jsonl"),
|
||||
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
||||
)
|
||||
|
||||
|
@ -101,7 +101,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
|
||||
content_config.plugins = {
|
||||
"plugin1": TextContentConfig(
|
||||
input_files=[content_dir.joinpath("notes.jsonl")],
|
||||
input_files=[content_dir.joinpath("notes.jsonl.gz")],
|
||||
input_filter=None,
|
||||
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
|
||||
|
@ -142,7 +142,7 @@ def md_content_config(tmp_path_factory):
|
|||
content_config.markdown = TextContentConfig(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl"),
|
||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
|
||||
)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
|
||||
# External Packages
|
||||
import pytest
|
||||
import torch
|
||||
from khoj.utils.config import SearchModels
|
||||
|
||||
# Internal Packages
|
||||
|
@ -17,7 +18,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
|||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup_with_missing_file_raises_error(
|
||||
def test_text_search_setup_with_missing_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
):
|
||||
# Arrange
|
||||
|
@ -32,7 +33,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup_with_empty_file_raises_error(
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
):
|
||||
# Act
|
||||
|
@ -42,7 +43,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(
|
||||
|
@ -55,7 +56,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchMo
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
|
||||
|
@ -70,8 +71,8 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
|
|||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert "📩 Saved computed text embeddings to" in initial_logs
|
||||
assert "📩 Saved computed text embeddings to" not in final_logs
|
||||
assert "Creating index from scratch." in initial_logs
|
||||
assert "Creating index from scratch." not in final_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -121,7 +122,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
||||
):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
|
@ -135,25 +138,20 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
|
|||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
# Act
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(regenerated_notes_model.entries) == 11
|
||||
assert len(regenerated_notes_model.corpus_embeddings) == 11
|
||||
|
||||
# Assert
|
||||
# verify new entry loaded from updated embeddings, entries
|
||||
assert len(initial_notes_model.entries) == 11
|
||||
assert len(initial_notes_model.corpus_embeddings) == 11
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, regenerated_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
|
@ -161,30 +159,101 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}{new_entry}")
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) == 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(initial_index, updated_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}{new_entry} -- Tatooine")
|
||||
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}")
|
||||
|
||||
# Act
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) + 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(updated_index, initial_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
f.write(new_entry)
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
final_notes_model = text_search.setup(
|
||||
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
)
|
||||
|
||||
# Assert
|
||||
# verify new entry added in updated embeddings, entries
|
||||
assert len(initial_notes_model.entries) == 11
|
||||
assert len(initial_notes_model.corpus_embeddings) == 11
|
||||
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
|
||||
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, final_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
|
@ -202,3 +271,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S
|
|||
|
||||
# Assert
|
||||
assert len(github_model.entries) > 1
|
||||
|
||||
|
||||
def compare_index(initial_notes_model, final_notes_model):
|
||||
mismatched_entries, mismatched_embeddings = [], []
|
||||
for index in range(len(initial_notes_model.entries)):
|
||||
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
|
||||
mismatched_entries.append(index)
|
||||
|
||||
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
|
||||
for index in range(len(initial_notes_model.corpus_embeddings)):
|
||||
if not torch.equal(final_notes_model.corpus_embeddings[index], initial_notes_model.corpus_embeddings[index]):
|
||||
mismatched_embeddings.append(index)
|
||||
|
||||
error_details = ""
|
||||
if mismatched_entries:
|
||||
mismatched_entries_str = ",".join(map(str, mismatched_entries))
|
||||
error_details += f"Entries at {mismatched_entries_str} not equal\n"
|
||||
if mismatched_embeddings:
|
||||
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
|
||||
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
|
||||
|
||||
return error_details
|
||||
|
|
Loading…
Reference in a new issue