Mirror of https://github.com/khoj-ai/khoj.git
Improve Indexing Text Entries (#535)
Major
- Ensure search results logic consistent across migration to DB, multi-user
  - Manually verified search results for sample queries look the same across migration
- Flatten indexing code for better indexing progress tracking and code readability

Minor (a4f407f)
- Test memory leak on MPS device when generating vector embeddings (ef24485)
- Improve Khoj with DB setup instructions in the Django app readme (for now) (f212cc7)
- Arrange remaining text search tests in arrange, act, assert order (022017d)
- Fix text search tests to test updated indexing log messages
Commit 38f24a037d
11 changed files with 199 additions and 134 deletions
@@ -93,6 +93,7 @@ test = [
"factory-boy >= 3.2.1",
"trio >= 0.22.0",
"pytest-xdist",
"psutil >= 5.8.0",
]
dev = [
"khoj-assistant[test]",
@@ -17,16 +17,26 @@ docker-compose up

## Setup (Local)

### Install dependencies
### Install Postgres (with PgVector)

#### MacOS
- Install the [Postgres.app](https://postgresapp.com/).

#### Debian, Ubuntu
From [official instructions](https://wiki.postgresql.org/wiki/Apt)

```bash
pip install -e '.[dev]'
sudo apt install -y postgresql-common
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
sudo apt install postgresql-16 postgresql-16-pgvector
```

### Setup the database
#### Windows
- Use the [recommended installer](https://www.postgresql.org/download/windows/)

1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
#### From Source
1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) in case you need to manually install it. Reproduced instructions below for convenience.

```bash
cd /tmp

@@ -35,32 +45,50 @@ cd pgvector

make
make install # may need sudo
```
3. Create a database

### Create the khoj database
### Create the Khoj database

#### MacOS
```bash
createdb khoj -U postgres
```

### Make migrations
#### Debian, Ubuntu
```bash
sudo -u postgres createdb khoj
```
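Once the database exists, an optional, illustrative way to confirm the pgvector extension is available from Python. This assumes `psycopg2` is installed and Postgres is reachable locally as the default `postgres` user; it is not part of Khoj's own setup scripts:

```python
# Illustrative sanity check: enable and report the pgvector extension in the khoj database.
import psycopg2

conn = psycopg2.connect(dbname="khoj", user="postgres")
with conn, conn.cursor() as cur:
    cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector';")
    print("pgvector version:", cur.fetchone()[0])
conn.close()
```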

This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).

- [Optional] To set the default postgres user's password
  - Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
  - Run `export POSTGRES_PASSWORD=my_secure_password` in your terminal for Khoj to use it later (see the illustrative sketch below)
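A minimal, illustrative sketch of how such an environment variable is typically consumed in a Django database configuration. The settings keys shown here are assumptions for illustration, not Khoj's actual settings module:

```python
# Illustrative only: a typical Django DATABASES block reading the password from the
# POSTGRES_PASSWORD environment variable set above. Khoj's real settings may differ.
import os

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": "khoj",
        "USER": "postgres",
        "PASSWORD": os.getenv("POSTGRES_PASSWORD", ""),
        "HOST": os.getenv("POSTGRES_HOST", "localhost"),
        "PORT": os.getenv("POSTGRES_PORT", "5432"),
    }
}
```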

### Install Khoj

```bash
pip install -e '.[dev]'
```

### Make Khoj DB migrations

This command will create the migrations for the database app. This command should be run whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).

```bash
python3 src/manage.py makemigrations
```

### Run migrations
### Run Khoj DB migrations

This command will run any pending migrations in your application.

```bash
python3 src/manage.py migrate
```

### Run the server
### Start Khoj Server

While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.

*Note: Anonymous mode bypasses authentication for local, single-user usage.*

```bash
python3 src/khoj/main.py
python3 src/khoj/main.py --anonymous-mode
```
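For orientation, a minimal sketch of the general "Django ORM inside a FastAPI app" pattern described above. The settings module and model names are hypothetical; this is not Khoj's actual main.py:

```python
# Illustrative sketch: initialize Django's ORM before importing models, then serve via FastAPI.
import os

import django
from fastapi import FastAPI

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")  # hypothetical settings module
django.setup()  # scaffold Django apps and the ORM

from database.models import Entry  # hypothetical model import; must come after django.setup()

app = FastAPI()


@app.get("/entries/count")
def count_entries():
    # Sync route: FastAPI runs it in a threadpool, so plain Django ORM calls are safe here
    return {"count": Entry.objects.count()}
```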

@@ -1,4 +1,3 @@
import secrets
from typing import Type, TypeVar, List
from datetime import date
import secrets

@@ -36,9 +35,6 @@ from database.models import (
OfflineChatProcessorConversationConfig,
)
from khoj.utils.helpers import generate_random_name
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
@@ -1,7 +1,6 @@
from typing import List

from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder

from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse

@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse

class EmbeddingsModel:
def __init__(self):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

def embed_query(self, query):
return self.embeddings_model.embed_query(query)
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]

def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()


class CrossEncoderModel:
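The hunk above swaps LangChain's HuggingFaceEmbeddings wrapper for direct sentence-transformers calls. A standalone, illustrative sketch of the encode usage the new EmbeddingsModel relies on (not Khoj's module itself):

```python
# Illustrative sketch of the sentence-transformers calls mirrored in the diff above.
from sentence_transformers import SentenceTransformer

# In Khoj, get_device() picks cuda/mps/cpu; "cpu" is hard-coded here for illustration.
model = SentenceTransformer("thenlper/gte-small", device="cpu")

# Query: encode a one-element batch quietly and take the first normalized vector.
query_embedding = model.encode(["how to install khoj"], normalize_embeddings=True, show_progress_bar=False)[0]

# Documents: encode in bulk with a progress bar; convert to plain lists before database storage.
doc_embeddings = model.encode(["first note", "second note"], normalize_embeddings=True, show_progress_bar=True).tolist()

print(len(query_embedding), len(doc_embeddings))  # embedding dimension, number of documents
```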

@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
index_heading_entries = True
index_heading_entries = False

if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
@@ -1,11 +1,12 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
from itertools import repeat
import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher
from khoj.utils.helpers import is_none_or_empty, timer, batcher


# Internal Packages

@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))

num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = 0
if regenerate:
with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)

num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
hashes_to_process |= hashes_for_file - existing_entry_hashes

entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
embeddings = []
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)

with timer("Update the database with new vector embeddings", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
added_entries: list[DbEntry] = []
with timer("Added entries to database in", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)

for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")

dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for new_entry in new_entries:
dates = self.date_filter.extract_dates(new_entry.raw)
for date in dates:
dates_to_create.append(
EntryDates(
date=date,
entry=new_entry,
)
)
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
dates_to_create = [
EntryDates(date=date, entry=added_entry)
for date, added_entry in dates_in_entries
if not is_none_or_empty(date)
]
new_dates += EntryDates.objects.bulk_create(dates_to_create)
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")

with timer("Identify hashes for removed entries", logger):
with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))

with timer("Identify hashes for deleting entries", logger):
with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count
num_deleted_entries += deleted_count

return num_new_embeddings, num_deleted_embeddings
return len(added_entries), num_deleted_entries

@staticmethod
def mark_entries_for_update(
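The flattened indexing loop above stores entries with `DbEntry.objects.bulk_create` in chunks produced by `batcher` from `khoj.utils.helpers`. A minimal, illustrative sketch of what such a batching helper does (not Khoj's implementation):

```python
# Illustrative batching helper: yield successive fixed-size chunks from any iterable;
# the final chunk may be smaller. Mirrors how the diff batches (hash, embedding) pairs.
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batcher(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# Example: 5 (hash, embedding) pairs processed in batches of 2 -> batch sizes 2, 2, 1
pairs = zip([f"hash-{i}" for i in range(5)], [[0.0, 0.1]] * 5)
for batch in batcher(pairs, 2):
    print(len(batch))
```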

@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
logger.info(f"Loading content from existing embeddings...")
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None

@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files]

logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)

@@ -51,7 +51,8 @@ def cli(args=None):

args, remaining_args = parser.parse_known_args(args)

logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")

# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu
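For context, a small illustrative example of the `argparse.parse_known_args` behaviour the change above relies on: unknown flags are returned rather than raising, which is what lets the CLI warn and continue:

```python
# Illustrative: parse_known_args returns (parsed args, list of unrecognized args).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-chat-on-gpu", action="store_true")

args, remaining_args = parser.parse_known_args(["--disable-chat-on-gpu", "--unknown-flag", "123"])
print(args.disable_chat_on_gpu)  # True
print(remaining_args)            # ['--unknown-flag', '123']
```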

@@ -1,3 +1,14 @@
# Standard Packages
import numpy as np
import psutil
from scipy.stats import linregress
import secrets

# External Packages
import pytest

# Internal Packages
from khoj.processor.embeddings import EmbeddingsModel
from khoj.utils import helpers

@@ -44,3 +55,29 @@ def test_lru_cache():
cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4}


@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
# Arrange
iterations = 50
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []

# Act
# Encode random strings repeatedly and record memory usage trend
for iteration in range(iterations):
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
a = [embeddings_model.embed_documents(random_docs)]
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)

# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)

# Assert
# If slope is positive memory utilization is increasing
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
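The leak check above fits a line to the recorded memory samples and treats a slope above roughly 2 MB per iteration as growth. A tiny, self-contained illustration of that heuristic with made-up numbers:

```python
# Illustrative only: flag a leak when memory samples trend upward beyond a small threshold.
import numpy as np
from scipy.stats import linregress

memory_usage_mb = np.array([210.0, 210.4, 211.1, 211.9, 212.6])  # invented samples for illustration
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_mb)), memory_usage_mb)
print(f"~{slope:.2f} MB per iteration")
assert slope < 2, "memory usage appears to be growing"
```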

@@ -48,10 +48,11 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
user=default_user,
)

# Act
org_files = collect_files(user=default_user)["org"]

# Act
# should not raise IsADirectoryError and return orgfile
# Assert
# should return orgfile and not raise IsADirectoryError
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}

@@ -62,12 +63,14 @@ def test_text_search_setup_with_empty_file_raises_error(
):
# Arrange
data = get_org_files(org_config_with_only_new_file)

# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# Assert
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
verify_embeddings(0, default_user)

@@ -79,12 +82,15 @@ def test_text_indexer_deletes_embedding_before_regenerate(
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)

# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

# Assert
assert "Deleting all embeddings for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message


# ----------------------------------------------------------------------------------------------------

@@ -93,13 +99,14 @@ def test_text_search_setup_batch_processes(content_config: ContentConfig, defaul
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)

# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

# Assert
assert "Created 4 new embeddings" in caplog.text
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message


# ----------------------------------------------------------------------------------------------------

@@ -122,8 +129,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text

# Assert
assert "Deleting all embeddings for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs


# ----------------------------------------------------------------------------------------------------

@@ -135,7 +142,6 @@ async def test_text_search(search_config: SearchConfig):
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],

@@ -159,13 +165,12 @@ async def test_text_search(search_config: SearchConfig):

# Act
hits = await text_search.query(default_user, query)

# Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry

# Assert
search_result = results[0].entry
assert "git clone" in search_result
assert "git clone" in search_result, 'search result did not contain "git clone" entry'


# ----------------------------------------------------------------------------------------------------

@@ -188,8 +193,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)

# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"


# ----------------------------------------------------------------------------------------------------

@@ -245,8 +251,9 @@ conda activate khoj
)

# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"


# ----------------------------------------------------------------------------------------------------

@@ -256,27 +263,29 @@ def test_regenerate_index_with_new_entry(
):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)

with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)

assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
initial_data = get_org_files(org_config)

# append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

data = get_org_files(org_config)
final_data = get_org_files(org_config)

# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs

# regenerate notes jsonl, model embeddings and model to include entry from new file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text

# Assert
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
verify_embeddings(11, default_user)

@@ -311,8 +320,8 @@ def test_update_index_with_duplicate_entries_in_stable_order(

# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs

verify_embeddings(1, default_user)

@@ -327,29 +336,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file)

# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
initial_data = get_org_files(org_config_with_only_new_file)

# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")

data = get_org_files(org_config_with_only_new_file)
final_data = get_org_files(org_config_with_only_new_file)

# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs

with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text

# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs

verify_embeddings(1, default_user)

@@ -379,9 +388,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
final_logs = caplog.text

# Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs

assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(11, default_user)

@@ -390,6 +398,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()

# Act
# Regenerate github embeddings to test asymmetric setup without caching
text_search.setup(