Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Merge pull request #325 from khoj-ai/stablize-simplify-content-indexing
## Stabilize and Simplify Content Indexing

### Major Updates
- 9bcca43 Unify logic to update entries when indexing from scratch or incrementally
- 89c7819 Unify logic to update embeddings when indexing from scratch or incrementally
- 6a0297c Stable sort new entries when marking entries for update
- 58d86d7 Unify logic to configure server from API or on server start
- Create tests to ensure old entries, embeddings in index are unaffected on adding new entries
  - Refer: 1482fd4, 7669b85, 88d1a29
- ad41ef3 Make normalization of embeddings configurable to test this in c73feeb

### Minor Updates
- 1673bb5 Add todo state to compiled form of each entry
- 6e70b91 Remove unused `dump_jsonl` helper method
- 7ad9603 Improve naming of lock
- b02323a Improve naming text search test methods

Resolves #190
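A minimal sketch of the "stable sort new entries" idea from the major updates above. This is an illustrative assumption-laden example, not the project's actual code: entries are plain strings here rather than Khoj's `Entry` objects, and the helper name is hypothetical. The idea it demonstrates is the one described in the bullets: entries already present in the previous index keep their old ids so their embeddings can be reused, while genuinely new entries are flagged with id -1 and kept in processing order, so repeated runs lay out the index identically.

```python
from typing import List, Tuple


# Hypothetical sketch of stable entry marking; the real implementation hashes
# compiled entries and lives in TextToJsonl.mark_entries_for_update.
def mark_entries_for_update(current: List[str], previous: List[str]) -> List[Tuple[int, str]]:
    previous_ids = {entry: idx for idx, entry in enumerate(previous)}

    # Existing entries keep their previous id so their embeddings can be reused,
    # sorted by that id to preserve the old index order.
    existing = sorted(
        ((previous_ids[e], e) for e in current if e in previous_ids), key=lambda x: x[0]
    )

    # New entries are flagged with id -1 (embedding still to be computed) and kept
    # in processing order, which keeps the resulting index layout stable across runs.
    new = [(-1, e) for e in current if e not in previous_ids]

    return existing + new


if __name__ == "__main__":
    old = ["a", "b", "c"]
    current = ["b", "d", "a", "e"]
    print(mark_entries_for_update(current, old))
    # [(0, 'a'), (1, 'b'), (-1, 'd'), (-1, 'e')]
```

Appending new entries in a deterministic order, instead of whatever order a set yields, is what lets the new tests compare old and updated indices entry-by-entry and embedding-by-embedding.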
Commit d00c5da8b7

15 changed files with 280 additions and 208 deletions
@@ -37,45 +37,63 @@ from khoj.search_filter.file_filter import FileFilter
 logger = logging.getLogger(__name__)
 
 
-def configure_server(args, required=False):
-    if args.config is None:
-        if required:
-            logger.error(
-                f"Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
-            )
-            sys.exit(1)
-        else:
-            logger.warning(
-                f"Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
-            )
-            return
-    else:
-        state.config = args.config
+def initialize_server(
+    config: Optional[FullConfig], regenerate: bool, type: Optional[SearchType] = None, required=False
+):
+    if config is None and required:
+        logger.error(
+            f"🚨 Exiting as Khoj is not configured.\nConfigure it via http://localhost:42110/config or by editing {state.config_file}."
+        )
+        sys.exit(1)
+    elif config is None:
+        logger.warning(
+            f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
+        )
+        return None
+
+    try:
+        configure_server(config, regenerate, type)
+    except Exception as e:
+        logger.error(f"🚨 Failed to configure server on app load: {e}", exc_info=True)
+
+
+def configure_server(config: FullConfig, regenerate: bool, search_type: Optional[SearchType] = None):
+    # Update Config
+    state.config = config
 
     # Initialize Processor from Config
-    state.processor_config = configure_processor(args.config.processor)
+    try:
+        state.config_lock.acquire()
+        state.processor_config = configure_processor(state.config.processor)
+    except Exception as e:
+        logger.error(f"🚨 Failed to configure processor")
+        raise e
+    finally:
+        state.config_lock.release()
 
     # Initialize Search Models from Config
     try:
-        state.search_index_lock.acquire()
+        state.config_lock.acquire()
         state.SearchType = configure_search_types(state.config)
         state.search_models = configure_search(state.search_models, state.config.search_type)
     except Exception as e:
-        logger.error(f"🚨 Error configuring search models on app load: {e}")
+        logger.error(f"🚨 Failed to configure search models")
+        raise e
     finally:
-        state.search_index_lock.release()
+        state.config_lock.release()
 
     # Initialize Content from Config
     if state.search_models:
         try:
-            state.search_index_lock.acquire()
+            state.config_lock.acquire()
             state.content_index = configure_content(
-                state.content_index, state.config.content_type, state.search_models, args.regenerate
+                state.content_index, state.config.content_type, state.search_models, regenerate, search_type
             )
         except Exception as e:
-            logger.error(f"🚨 Error configuring content index on app load: {e}")
+            logger.error(f"🚨 Failed to index content")
+            raise e
         finally:
-            state.search_index_lock.release()
+            state.config_lock.release()
 
 
 def configure_routes(app):

@@ -95,7 +113,7 @@ if not state.demo:
     @schedule.repeat(schedule.every(61).minutes)
     def update_search_index():
         try:
-            state.search_index_lock.acquire()
+            state.config_lock.acquire()
             state.content_index = configure_content(
                 state.content_index, state.config.content_type, state.search_models, regenerate=False
             )

@@ -103,7 +121,7 @@ if not state.demo:
         except Exception as e:
             logger.error(f"🚨 Error updating content index via Scheduler: {e}")
         finally:
-            state.search_index_lock.release()
+            state.config_lock.release()
 
 
 def configure_search_types(config: FullConfig):

@@ -118,10 +136,10 @@ def configure_search_types(config: FullConfig):
     return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
 
 
-def configure_search(search_models: SearchModels, search_config: SearchConfig) -> Optional[SearchModels]:
+def configure_search(search_models: SearchModels, search_config: Optional[SearchConfig]) -> Optional[SearchModels]:
     # Run Validation Checks
     if search_config is None:
-        logger.warning("🚨 No Search type is configured.")
+        logger.warning("🚨 No Search configuration available.")
         return None
     if search_models is None:
         search_models = SearchModels()

@@ -147,7 +165,7 @@ def configure_content(
 ) -> Optional[ContentIndex]:
     # Run Validation Checks
     if content_config is None:
-        logger.warning("🚨 No Content type is configured.")
+        logger.warning("🚨 No Content configuration available.")
         return None
     if content_index is None:
         content_index = ContentIndex()

@@ -242,9 +260,10 @@ def configure_content(
     return content_index
 
 
-def configure_processor(processor_config: ProcessorConfig):
+def configure_processor(processor_config: Optional[ProcessorConfig]):
     if not processor_config:
-        return
+        logger.warning("🚨 No Processor configuration available.")
+        return None
 
     processor = ProcessorConfigModel()
 
@@ -25,7 +25,7 @@ from rich.logging import RichHandler
 import schedule
 
 # Internal Packages
-from khoj.configure import configure_routes, configure_server
+from khoj.configure import configure_routes, initialize_server
 from khoj.utils import state
 from khoj.utils.cli import cli
 

@@ -70,7 +70,7 @@ def run():
         poll_task_scheduler()
 
         # Start Server
-        configure_server(args, required=False)
+        initialize_server(args.config, args.regenerate, required=False)
         configure_routes(app)
         start_server(app, host=args.host, port=args.port, socket=args.socket)
     else:

@@ -93,7 +93,7 @@ def run():
         tray.show()
 
         # Setup Server
-        configure_server(args, required=False)
+        initialize_server(args.config, args.regenerate, required=False)
         configure_routes(app)
         server = ServerThread(start_server_func=lambda: start_server(app, host=args.host, port=args.port))
 
@@ -13,9 +13,8 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
 from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
 from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
 from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import compress_jsonl_data
 from khoj.utils.rawconfig import Entry
-from khoj.utils import state
 
 
 logger = logging.getLogger(__name__)

@@ -38,7 +37,7 @@ class GithubToJsonl(TextToJsonl):
         else:
             return
 
-    def process(self, previous_entries=None):
+    def process(self, previous_entries=[]):
         current_entries = []
         for repo in self.config.repos:
             current_entries += self.process_repo(repo)

@@ -98,10 +97,7 @@ class GithubToJsonl(TextToJsonl):
             jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if self.config.compressed_jsonl.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
-            elif self.config.compressed_jsonl.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, self.config.compressed_jsonl)
+            compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
 
         return entries_with_ids
 
@@ -7,7 +7,7 @@ from typing import List
 # Internal Packages
 from khoj.processor.text_to_jsonl import TextToJsonl
 from khoj.utils.helpers import get_absolute_path, timer
-from khoj.utils.jsonl import load_jsonl, dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
 from khoj.utils.rawconfig import Entry
 
 

@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 class JsonlToJsonl(TextToJsonl):
     # Define Functions
-    def process(self, previous_entries=None):
+    def process(self, previous_entries=[]):
         # Extract required fields from config
         input_jsonl_files, input_jsonl_filter, output_file = (
             self.config.input_files,

@@ -38,15 +38,9 @@ class JsonlToJsonl(TextToJsonl):
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
-            if not previous_entries:
-                entries_with_ids = list(enumerate(current_entries))
-            else:
-                entries_with_ids = TextToJsonl.mark_entries_for_update(
-                    current_entries,
-                    previous_entries,
-                    key="compiled",
-                    logger=logger,
-                )
+            entries_with_ids = TextToJsonl.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )
 
         with timer("Write entries to JSONL file", logger):
             # Process Each Entry from All Notes Files

@@ -54,10 +48,7 @@ class JsonlToJsonl(TextToJsonl):
             jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if output_file.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, output_file)
-            elif output_file.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, output_file)
+            compress_jsonl_data(jsonl_data, output_file)
 
         return entries_with_ids
 
@@ -10,7 +10,7 @@ from typing import List
 from khoj.processor.text_to_jsonl import TextToJsonl
 from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import compress_jsonl_data
 from khoj.utils.rawconfig import Entry, TextContentConfig
 
 

@@ -23,7 +23,7 @@ class MarkdownToJsonl(TextToJsonl):
         self.config = config
 
     # Define Functions
-    def process(self, previous_entries=None):
+    def process(self, previous_entries=[]):
         # Extract required fields from config
         markdown_files, markdown_file_filter, output_file = (
             self.config.input_files,

@@ -51,12 +51,9 @@ class MarkdownToJsonl(TextToJsonl):
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
-            if not previous_entries:
-                entries_with_ids = list(enumerate(current_entries))
-            else:
-                entries_with_ids = TextToJsonl.mark_entries_for_update(
-                    current_entries, previous_entries, key="compiled", logger=logger
-                )
+            entries_with_ids = TextToJsonl.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )
 
         with timer("Write markdown entries to JSONL file", logger):
             # Process Each Entry from All Notes Files

@@ -64,10 +61,7 @@ class MarkdownToJsonl(TextToJsonl):
             jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if output_file.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, output_file)
-            elif output_file.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, output_file)
+            compress_jsonl_data(jsonl_data, output_file)
 
         return entries_with_ids
 
@@ -8,7 +8,7 @@ import requests
 from khoj.utils.helpers import timer
 from khoj.utils.rawconfig import Entry, NotionContentConfig
 from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import compress_jsonl_data
 from khoj.utils.rawconfig import Entry
 
 from enum import Enum

@@ -80,7 +80,7 @@ class NotionToJsonl(TextToJsonl):
 
         self.body_params = {"page_size": 100}
 
-    def process(self, previous_entries=None):
+    def process(self, previous_entries=[]):
         current_entries = []
 
         # Get all pages

@@ -240,12 +240,9 @@ class NotionToJsonl(TextToJsonl):
     def update_entries_with_ids(self, current_entries, previous_entries):
         # Identify, mark and merge any new entries with previous entries
        with timer("Identify new or updated entries", logger):
-            if not previous_entries:
-                entries_with_ids = list(enumerate(current_entries))
-            else:
-                entries_with_ids = TextToJsonl.mark_entries_for_update(
-                    current_entries, previous_entries, key="compiled", logger=logger
-                )
+            entries_with_ids = TextToJsonl.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )
 
         with timer("Write Notion entries to JSONL file", logger):
             # Process Each Entry from all Notion entries

@@ -253,9 +250,6 @@ class NotionToJsonl(TextToJsonl):
             jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if self.config.compressed_jsonl.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
-            elif self.config.compressed_jsonl.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, self.config.compressed_jsonl)
+            compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
 
         return entries_with_ids
 
@@ -8,7 +8,7 @@ from typing import Iterable, List
 from khoj.processor.org_mode import orgnode
 from khoj.processor.text_to_jsonl import TextToJsonl
 from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
-from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import compress_jsonl_data
 from khoj.utils.rawconfig import Entry, TextContentConfig
 from khoj.utils import state
 

@@ -22,7 +22,7 @@ class OrgToJsonl(TextToJsonl):
         self.config = config
 
     # Define Functions
-    def process(self, previous_entries: List[Entry] = None):
+    def process(self, previous_entries: List[Entry] = []):
         # Extract required fields from config
         org_files, org_file_filter, output_file = (
             self.config.input_files,

@@ -51,9 +51,7 @@ class OrgToJsonl(TextToJsonl):
             current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
 
         # Identify, mark and merge any new entries with previous entries
-        if not previous_entries:
-            entries_with_ids = list(enumerate(current_entries))
-        else:
+        with timer("Identify new or updated entries", logger):
             entries_with_ids = TextToJsonl.mark_entries_for_update(
                 current_entries, previous_entries, key="compiled", logger=logger
             )

@@ -64,10 +62,7 @@ class OrgToJsonl(TextToJsonl):
             jsonl_data = self.convert_org_entries_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if output_file.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, output_file)
-            elif output_file.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, output_file)
+            compress_jsonl_data(jsonl_data, output_file)
 
         return entries_with_ids
 

@@ -125,9 +120,13 @@ class OrgToJsonl(TextToJsonl):
                 # Ignore title notes i.e notes with just headings and empty body
                 continue
 
+            todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
             # Prepend filename as top heading to entry
             filename = Path(entry_to_file_map[parsed_entry]).stem
-            heading = f"* {filename}\n** {parsed_entry.heading}." if parsed_entry.heading else f"* {filename}."
+            if parsed_entry.heading:
+                heading = f"* {filename}\n** {todo_str}{parsed_entry.heading}."
+            else:
+                heading = f"* {filename}."
 
             compiled = heading
             if state.verbose > 2:
@@ -10,7 +10,7 @@ from langchain.document_loaders import PyPDFLoader
 # Internal Packages
 from khoj.processor.text_to_jsonl import TextToJsonl
 from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
-from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
+from khoj.utils.jsonl import compress_jsonl_data
 from khoj.utils.rawconfig import Entry
 
 

@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 class PdfToJsonl(TextToJsonl):
     # Define Functions
-    def process(self, previous_entries=None):
+    def process(self, previous_entries=[]):
         # Extract required fields from config
         pdf_files, pdf_file_filter, output_file = (
             self.config.input_files,

@@ -45,12 +45,9 @@ class PdfToJsonl(TextToJsonl):
 
         # Identify, mark and merge any new entries with previous entries
         with timer("Identify new or updated entries", logger):
-            if not previous_entries:
-                entries_with_ids = list(enumerate(current_entries))
-            else:
-                entries_with_ids = TextToJsonl.mark_entries_for_update(
-                    current_entries, previous_entries, key="compiled", logger=logger
-                )
+            entries_with_ids = TextToJsonl.mark_entries_for_update(
+                current_entries, previous_entries, key="compiled", logger=logger
+            )
 
         with timer("Write PDF entries to JSONL file", logger):
             # Process Each Entry from All Notes Files

@@ -58,10 +55,7 @@ class PdfToJsonl(TextToJsonl):
             jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
 
             # Compress JSONL formatted Data
-            if output_file.suffix == ".gz":
-                compress_jsonl_data(jsonl_data, output_file)
-            elif output_file.suffix == ".jsonl":
-                dump_jsonl(jsonl_data, output_file)
+            compress_jsonl_data(jsonl_data, output_file)
 
         return entries_with_ids
 
@@ -17,7 +17,7 @@ class TextToJsonl(ABC):
         self.config = config
 
     @abstractmethod
-    def process(self, previous_entries: List[Entry] = None) -> List[Tuple[int, Entry]]:
+    def process(self, previous_entries: List[Entry] = []) -> List[Tuple[int, Entry]]:
         ...
 
     @staticmethod

@@ -78,16 +78,23 @@ class TextToJsonl(ABC):
         # All entries that exist in both current and previous sets are kept
         existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
 
+        # load new entries in the order in which they are processed for a stable sort
+        new_entries = [
+            (current_entry_hashes.index(entry_hash), hash_to_current_entries[entry_hash])
+            for entry_hash in new_entry_hashes
+        ]
+        new_entries_sorted = sorted(new_entries, key=lambda e: e[0])
         # Mark new entries with -1 id to flag for later embeddings generation
-        new_entries = [(-1, hash_to_current_entries[entry_hash]) for entry_hash in new_entry_hashes]
+        new_entries_sorted = [(-1, entry[1]) for entry in new_entries_sorted]
 
         # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
         existing_entries = [
             (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
             for entry_hash in existing_entry_hashes
         ]
 
         existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
-        entries_with_ids = existing_entries_sorted + new_entries
+
+        entries_with_ids = existing_entries_sorted + new_entries_sorted
 
         return entries_with_ids
 
@@ -5,20 +5,20 @@ import time
 import yaml
 import logging
 import json
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Union
 
 # External Packages
 from fastapi import APIRouter, HTTPException, Header, Request
 from sentence_transformers import util
 
 # Internal Packages
-from khoj.configure import configure_content, configure_processor, configure_search
+from khoj.configure import configure_processor, configure_server
 from khoj.search_type import image_search, text_search
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils.config import TextSearchModel
-from khoj.utils.helpers import log_telemetry, timer
+from khoj.utils.helpers import timer
 from khoj.utils.rawconfig import (
     ContentConfig,
     FullConfig,

@@ -524,34 +524,26 @@ def update(
     referer: Optional[str] = Header(None),
     host: Optional[str] = Header(None),
 ):
+    if not state.config:
+        error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
+        logger.warning(error_msg)
+        raise HTTPException(status_code=500, detail=error_msg)
     try:
-        state.search_index_lock.acquire()
-        try:
-            if state.config and state.config.search_type:
-                state.search_models = configure_search(state.search_models, state.config.search_type)
-            if state.search_models:
-                state.content_index = configure_content(
-                    state.content_index, state.config.content_type, state.search_models, regenerate=force or False, t=t
-                )
-        except Exception as e:
-            logger.error(e)
-            raise HTTPException(status_code=500, detail=str(e))
-        finally:
-            state.search_index_lock.release()
-    except ValueError as e:
-        logger.error(e)
-        raise HTTPException(status_code=500, detail=str(e))
+        configure_server(state.config, regenerate=force or False, search_type=t)
+    except Exception as e:
+        error_msg = f"🚨 Failed to update server via API: {e}"
+        logger.error(error_msg, exc_info=True)
+        raise HTTPException(status_code=500, detail=error_msg)
     else:
-        logger.info("📬 Search index updated via API")
-    try:
-        if state.config and state.config.processor:
-            state.processor_config = configure_processor(state.config.processor)
-    except ValueError as e:
-        logger.error(e)
-        raise HTTPException(status_code=500, detail=str(e))
-    else:
-        logger.info("📬 Processor reconfigured via API")
+        components = []
+        if state.search_models:
+            components.append("Search models")
+        if state.content_index:
+            components.append("Content index")
+        if state.processor_config:
+            components.append("Conversation processor")
+        components_msg = ", ".join(components)
+        logger.info(f"📬 {components_msg} updated via API")
 
     update_telemetry_state(
         request=request,
@@ -58,43 +58,48 @@ def extract_entries(jsonl_file) -> List[Entry]:
 
 
 def compute_embeddings(
-    entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
+    entries_with_ids: List[Tuple[int, Entry]],
+    bi_encoder: BaseEncoder,
+    embeddings_file: Path,
+    regenerate=False,
+    normalize=True,
 ):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
-    new_entries = []
+    new_embeddings = torch.tensor([], device=state.device)
+    existing_embeddings = torch.tensor([], device=state.device)
+    create_index_msg = ""
     # Load pre-computed embeddings from file if exists and update them if required
     if embeddings_file.exists() and not regenerate:
         corpus_embeddings: torch.Tensor = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
         logger.debug(f"Loaded {len(corpus_embeddings)} text embeddings from {embeddings_file}")
-
-        # Encode any new entries in the corpus and update corpus embeddings
-        new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
-        if new_entries:
-            logger.info(f"📩 Indexing {len(new_entries)} text entries.")
-            new_embeddings = bi_encoder.encode(
-                new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
-            )
-            existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
-            if existing_entry_ids:
-                existing_embeddings = torch.index_select(
-                    corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
-                )
-            else:
-                existing_embeddings = torch.tensor([], device=state.device)
-            corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
-    # Else compute the corpus embeddings from scratch
     else:
-        new_entries = [entry.compiled for _, entry in entries_with_ids]
-        logger.info(f"📩 Indexing {len(new_entries)} text entries. Creating index from scratch.")
-        corpus_embeddings = bi_encoder.encode(
+        corpus_embeddings = torch.tensor([], device=state.device)
+        create_index_msg = " Creating index from scratch."
+
+    # Encode any new entries in the corpus and update corpus embeddings
+    new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
+    if new_entries:
+        logger.info(f"📩 Indexing {len(new_entries)} text entries.{create_index_msg}")
+        new_embeddings = bi_encoder.encode(
             new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
         )
 
-    # Save regenerated or updated embeddings to file
-    if new_entries:
-        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
-        torch.save(corpus_embeddings, embeddings_file)
-        logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
+    # Extract existing embeddings from previous corpus embeddings
+    existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
+    if existing_entry_ids:
+        existing_embeddings = torch.index_select(
+            corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
+        )
+
+    # Set corpus embeddings to merger of existing and new embeddings
+    corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
+    if normalize:
+        # Normalize embeddings for faster lookup via dot product when querying
+        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
+    # Save regenerated or updated embeddings to file
+    torch.save(corpus_embeddings, embeddings_file)
+    logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
 
     return corpus_embeddings
 

@@ -173,13 +178,14 @@ def setup(
     bi_encoder: BaseEncoder,
     regenerate: bool,
     filters: List[BaseFilter] = [],
+    normalize: bool = True,
 ) -> TextContent:
     # Map notes in text files to (compressed) JSONL formatted file
     config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
-    previous_entries = (
-        extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
-    )
-    entries_with_indices = text_to_jsonl(config).process(previous_entries or [])
+    previous_entries = []
+    if config.compressed_jsonl.exists() and not regenerate:
+        previous_entries = extract_entries(config.compressed_jsonl)
+    entries_with_indices = text_to_jsonl(config).process(previous_entries)
 
     # Extract Updated Entries
     entries = extract_entries(config.compressed_jsonl)

@@ -190,7 +196,7 @@ def setup(
     # Compute or Load Embeddings
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(
-        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
+        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
     )
 
     for filter in filters:
@@ -20,7 +20,7 @@ def load_jsonl(input_path):
     # Open JSONL file
     if input_path.suffix == ".gz":
         jsonl_file = gzip.open(get_absolute_path(input_path), "rt", encoding="utf-8")
-    elif input_path.suffix == ".jsonl":
+    else:
         jsonl_file = open(get_absolute_path(input_path), "r", encoding="utf-8")
 
     # Read JSONL file

@@ -36,17 +36,6 @@ def load_jsonl(input_path):
     return data
 
 
-def dump_jsonl(jsonl_data, output_path):
-    "Write List of JSON objects to JSON line file"
-    # Create output directory, if it doesn't exist
-    output_path.parent.mkdir(parents=True, exist_ok=True)
-
-    with open(output_path, "w", encoding="utf-8") as f:
-        f.write(jsonl_data)
-
-    logger.debug(f"Wrote jsonl data to {output_path}")
-
-
 def compress_jsonl_data(jsonl_data, output_path):
     # Create output directory, if it doesn't exist
     output_path.parent.mkdir(parents=True, exist_ok=True)
@@ -24,7 +24,7 @@ host: str = None
 port: int = None
 cli_args: List[str] = None
 query_cache = LRU()
-search_index_lock = threading.Lock()
+config_lock = threading.Lock()
 SearchType = utils_config.SearchType
 telemetry: List[Dict[str, str]] = []
 previous_query: str = None
@@ -90,7 +90,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
     content_config.org = TextContentConfig(
         input_files=None,
         input_filter=["tests/data/org/*.org"],
-        compressed_jsonl=content_dir.joinpath("notes.jsonl"),
+        compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
         embeddings_file=content_dir.joinpath("note_embeddings.pt"),
     )
 

@@ -101,7 +101,7 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
 
     content_config.plugins = {
         "plugin1": TextContentConfig(
-            input_files=[content_dir.joinpath("notes.jsonl")],
+            input_files=[content_dir.joinpath("notes.jsonl.gz")],
             input_filter=None,
             compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
             embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),

@@ -142,7 +142,7 @@ def md_content_config(tmp_path_factory):
     content_config.markdown = TextContentConfig(
         input_files=None,
         input_filter=["tests/data/markdown/*.markdown"],
-        compressed_jsonl=content_dir.joinpath("markdown.jsonl"),
+        compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
         embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
     )
 
@@ -5,6 +5,7 @@ import os
 
 # External Packages
 import pytest
+import torch
 from khoj.utils.config import SearchModels
 
 # Internal Packages

@@ -17,7 +18,7 @@ from khoj.processor.github.github_to_jsonl import GithubToJsonl
 
 # Test
 # ----------------------------------------------------------------------------------------------------
-def test_asymmetric_setup_with_missing_file_raises_error(
+def test_text_search_setup_with_missing_file_raises_error(
     org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
 ):
     # Arrange

@@ -32,7 +33,7 @@ def test_asymmetric_setup_with_missing_file_raises_error(
 
 
 # ----------------------------------------------------------------------------------------------------
-def test_asymmetric_setup_with_empty_file_raises_error(
+def test_text_search_setup_with_empty_file_raises_error(
     org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
 ):
     # Act

@@ -42,7 +43,7 @@ def test_asymmetric_setup_with_empty_file_raises_error(
 
 
 # ----------------------------------------------------------------------------------------------------
-def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
+def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
     # Act
     # Regenerate notes embeddings during asymmetric setup
     notes_model = text_search.setup(

@@ -55,7 +56,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchMo
 
 
 # ----------------------------------------------------------------------------------------------------
-def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
+def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
     # Arrange
     caplog.set_level(logging.INFO, logger="khoj")
 

@@ -70,8 +71,8 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
     final_logs = caplog.text
 
     # Assert
-    assert "📩 Saved computed text embeddings to" in initial_logs
-    assert "📩 Saved computed text embeddings to" not in final_logs
+    assert "Creating index from scratch." in initial_logs
+    assert "Creating index from scratch." not in final_logs
 
 
 # ----------------------------------------------------------------------------------------------------

@@ -121,7 +122,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
 
 
 # ----------------------------------------------------------------------------------------------------
-def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+def test_regenerate_index_with_new_entry(
+    content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
+):
     # Arrange
     initial_notes_model = text_search.setup(
         OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True

@@ -135,25 +138,20 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
     with open(new_org_file, "w") as f:
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
 
+    # Act
     # regenerate notes jsonl, model embeddings and model to include entry from new file
     regenerated_notes_model = text_search.setup(
         OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
     )
 
-    # Act
-    # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
-    )
-
     # Assert
     assert len(regenerated_notes_model.entries) == 11
     assert len(regenerated_notes_model.corpus_embeddings) == 11
 
-    # Assert
-    # verify new entry loaded from updated embeddings, entries
-    assert len(initial_notes_model.entries) == 11
-    assert len(initial_notes_model.corpus_embeddings) == 11
+    # verify new entry appended to index, without disrupting order or content of existing entries
+    error_details = compare_index(initial_notes_model, regenerated_notes_model)
+    if error_details:
+        pytest.fail(error_details, False)
 
     # Cleanup
     # reset input_files in config to empty list

@@ -161,30 +159,101 @@ def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchM
 
 
 # ----------------------------------------------------------------------------------------------------
-def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+def test_update_index_with_duplicate_entries_in_stable_order(
+    org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+):
     # Arrange
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
+    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
+
+    # Insert org-mode entries with same compiled form into new org file
+    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+    with open(new_file_to_index, "w") as f:
+        f.write(f"{new_entry}{new_entry}")
+
+    # Act
+    # load embeddings, entries, notes model after adding new org-mode file
+    initial_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
     )
 
-    assert len(initial_notes_model.entries) == 10
-    assert len(initial_notes_model.corpus_embeddings) == 10
+    # update embeddings, entries, notes model after adding new org-mode file
+    updated_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
+    )
+
+    # Assert
+    # verify only 1 entry added even if there are multiple duplicate entries
+    assert len(initial_index.entries) == len(updated_index.entries) == 1
+    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
+
+    # verify the same entry is added even when there are multiple duplicate entries
+    error_details = compare_index(initial_index, updated_index)
+    if error_details:
+        pytest.fail(error_details)
+
+
+# ----------------------------------------------------------------------------------------------------
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+    # Arrange
+    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
+
+    # Insert org-mode entries with same compiled form into new org file
+    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+    with open(new_file_to_index, "w") as f:
+        f.write(f"{new_entry}{new_entry} -- Tatooine")
+
+    # load embeddings, entries, notes model after adding new org file with 2 entries
+    initial_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
+    )
+
+    # update embeddings, entries, notes model after removing an entry from the org file
+    with open(new_file_to_index, "w") as f:
+        f.write(f"{new_entry}")
+
+    # Act
+    updated_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
+    )
+
+    # Assert
+    # verify only 1 entry added even if there are multiple duplicate entries
+    assert len(initial_index.entries) == len(updated_index.entries) + 1
+    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
+
+    # verify the same entry is added even when there are multiple duplicate entries
+    error_details = compare_index(updated_index, initial_index)
+    if error_details:
+        pytest.fail(error_details)
+
+
+# ----------------------------------------------------------------------------------------------------
+def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+    # Arrange
+    initial_notes_model = text_search.setup(
+        OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
+    )
 
     # append org-mode entry to first org input file in config
     with open(new_org_file, "w") as f:
-        f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
+        new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+        f.write(new_entry)
 
     # Act
     # update embeddings, entries with the newly added note
     content_config.org.input_files = [f"{new_org_file}"]
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
+    final_notes_model = text_search.setup(
+        OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
    )
 
     # Assert
-    # verify new entry added in updated embeddings, entries
-    assert len(initial_notes_model.entries) == 11
-    assert len(initial_notes_model.corpus_embeddings) == 11
+    assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
+    assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
+
+    # verify new entry appended to index, without disrupting order or content of existing entries
+    error_details = compare_index(initial_notes_model, final_notes_model)
+    if error_details:
+        pytest.fail(error_details, False)
 
     # Cleanup
     # reset input_files in config to empty list

@@ -202,3 +271,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S
 
     # Assert
     assert len(github_model.entries) > 1
+
+
+def compare_index(initial_notes_model, final_notes_model):
+    mismatched_entries, mismatched_embeddings = [], []
+    for index in range(len(initial_notes_model.entries)):
+        if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
+            mismatched_entries.append(index)
+
+    # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
+    for index in range(len(initial_notes_model.corpus_embeddings)):
+        if not torch.equal(final_notes_model.corpus_embeddings[index], initial_notes_model.corpus_embeddings[index]):
+            mismatched_embeddings.append(index)
+
+    error_details = ""
+    if mismatched_entries:
+        mismatched_entries_str = ",".join(map(str, mismatched_entries))
+        error_details += f"Entries at {mismatched_entries_str} not equal\n"
+    if mismatched_embeddings:
+        mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
+        error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
+
+    return error_details