Improve Indexing Text Entries (#535)

Major
- Ensure search results logic stays consistent across the migration to the DB-backed, multi-user setup
- Manually verified that search results for sample queries look the same across the migration
- Flatten indexing code for better indexing progress tracking and code readability

Minor
- a4f407f Test memory leak on MPS device when generating vector embeddings
ef24485 Improve the Khoj-with-DB setup instructions in the Django app readme (for now)
- f212cc7 Arrange remaining text search tests in arrange, act, assert order
- 022017d Fix text search tests to test updated indexing log messages
Commit 38f24a037d by Debanjum, 2023-11-06 16:01:53 -08:00, committed via GitHub
11 changed files with 199 additions and 134 deletions


@@ -93,6 +93,7 @@ test = [
"factory-boy >= 3.2.1",
"trio >= 0.22.0",
"pytest-xdist",
"psutil >= 5.8.0",
]
dev = [
"khoj-assistant[test]",


@@ -17,16 +17,26 @@ docker-compose up
## Setup (Local)
### Install dependencies
### Install Postgres (with PgVector)
#### MacOS
- Install the [Postgres.app](https://postgresapp.com/).
#### Debian, Ubuntu
From [official instructions](https://wiki.postgresql.org/wiki/Apt)
```bash
pip install -e '.[dev]'
sudo apt install -y postgresql-common
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
sudo apt install postgresql-16 postgresql-16-pgvector
```
### Setup the database
#### Windows
- Use the [recommended installer](https://www.postgresql.org/download/windows/)
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
#### From Source
1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) if you need to install it manually. Instructions are reproduced below for convenience.
```bash
cd /tmp
@@ -35,32 +45,50 @@ cd pgvector
make
make install # may need sudo
```
3. Create a database
### Create the khoj database
### Create the Khoj database
#### MacOS
```bash
createdb khoj -U postgres
```
### Make migrations
#### Debian, Ubuntu
```bash
sudo -u postgres createdb khoj
```
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
- [Optional] To set the default postgres user's password
  - Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
  - Run `export POSTGRES_PASSWORD=my_secure_password` in your terminal so Khoj can use it later
### Install Khoj
```bash
pip install -e '.[dev]'
```
### Make Khoj DB migrations
This command creates the migrations for the database app. Run it whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).
```bash
python3 src/manage.py makemigrations
```
### Run migrations
### Run Khoj DB migrations
This command will run any pending migrations in your application.
```bash
python3 src/manage.py migrate
```
### Run the server
### Start Khoj Server
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
*Note: Anonymous mode bypasses authentication for local, single-user usage.*
```bash
python3 src/khoj/main.py
python3 src/khoj/main.py --anonymous-mode
```
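For readers curious how a FastAPI entrypoint can scaffold the Django ORM in the backend, here is a minimal illustrative sketch of the general pattern. The settings module path and the `Entry` model name are assumptions made for illustration, not necessarily Khoj's actual layout.

```python
# Minimal sketch: initialize Django's ORM before a FastAPI app imports any models.
# "app.settings" and the Entry model below are assumptions for illustration only.
import os

import django
from fastapi import FastAPI

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
django.setup()  # wires up Django apps and the ORM so model imports below succeed

from database.models import Entry  # safe to import only after django.setup()

app = FastAPI()


@app.get("/health")
def health():
    # Sync endpoint: FastAPI runs it in a threadpool, so the sync Django ORM is safe here.
    return {"entries": Entry.objects.count()}
```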


@@ -1,4 +1,3 @@
import secrets
from typing import Type, TypeVar, List
from datetime import date
import secrets
@@ -36,9 +35,6 @@ from database.models import (
OfflineChatProcessorConversationConfig,
)
from khoj.utils.helpers import generate_random_name
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter


@@ -1,7 +1,6 @@
from typing import List
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
return self.embeddings_model.embed_query(query)
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
class CrossEncoderModel:

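For context on the change above, here is a small standalone sketch of how the SentenceTransformer-based pattern behaves, assuming the sentence-transformers package is installed (the shape comment assumes gte-small's 384-dimensional embeddings):

```python
from sentence_transformers import SentenceTransformer

# Mirrors the pattern the diff adopts: call encode() directly instead of going
# through langchain's HuggingFaceEmbeddings wrapper.
model = SentenceTransformer("thenlper/gte-small", device="cpu")

docs = ["Khoj is an open-source personal AI", "pgvector stores embeddings in Postgres"]
doc_vectors = model.encode(docs, normalize_embeddings=True, show_progress_bar=True)
query_vector = model.encode(["personal AI"], normalize_embeddings=True, show_progress_bar=False)[0]

print(doc_vectors.shape)  # (2, 384) for gte-small
print(len(query_vector))  # 384
```

Calling encode() directly also makes it possible to toggle the progress bar per call, which the diff uses to show progress for document batches but suppress it for single queries.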

@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
index_heading_entries = True
index_heading_entries = False
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])


@@ -1,11 +1,12 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
from itertools import repeat
import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher
from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = 0
if regenerate:
with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
hashes_to_process |= hashes_for_file - existing_entry_hashes
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
embeddings = []
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
added_entries: list[DbEntry] = []
with timer("Added entries to database in", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for new_entry in new_entries:
dates = self.date_filter.extract_dates(new_entry.raw)
for date in dates:
dates_to_create.append(
EntryDates(
date=date,
entry=new_entry,
)
)
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
dates_to_create = [
EntryDates(date=date, entry=added_entry)
for date, added_entry in dates_in_entries
if not is_none_or_empty(date)
]
new_dates += EntryDates.objects.bulk_create(dates_to_create)
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
with timer("Identify hashes for removed entries", logger):
with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count
num_deleted_entries += deleted_count
return num_new_embeddings, num_deleted_embeddings
return len(added_entries), num_deleted_entries
@staticmethod
def mark_entries_for_update(

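The flattened flow above leans on two small helpers, `timer` and `batcher`, imported from `khoj.utils.helpers`. Their exact implementations aren't shown in this diff, so here is an illustrative sketch of how helpers of this kind typically behave, matching how they are called above:

```python
import logging
import time
from contextlib import contextmanager
from itertools import islice

logger = logging.getLogger(__name__)


@contextmanager
def timer(message: str, log: logging.Logger):
    # Log how long the wrapped block took, e.g. "Added entries to database in: 0.123 seconds".
    start = time.perf_counter()
    yield
    log.debug(f"{message}: {time.perf_counter() - start:.3f} seconds")


def batcher(iterable, batch_size: int):
    # Yield successive lists of up to batch_size items from any iterable.
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        yield batch


# Usage mirroring the indexing loop: do work in fixed-size batches under a timed block.
with timer("Processed items in", logger):
    for batch in batcher(range(10), 4):
        print(batch)  # [0, 1, 2, 3] then [4, 5, 6, 7] then [8, 9]
```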

@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
logger.info(f"Loading content from existing embeddings...")
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None


@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files]
logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)


@@ -51,7 +51,8 @@ def cli(args=None):
args, remaining_args = parser.parse_known_args(args)
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu


@@ -1,3 +1,14 @@
# Standard Packages
import numpy as np
import psutil
from scipy.stats import linregress
import secrets
# External Packages
import pytest
# Internal Packages
from khoj.processor.embeddings import EmbeddingsModel
from khoj.utils import helpers
@@ -44,3 +55,29 @@ def test_lru_cache():
cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4}
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
# Arrange
iterations = 50
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []
# Act
# Encode random strings repeatedly and record memory usage trend
for iteration in range(iterations):
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
a = [embeddings_model.embed_documents(random_docs)]
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
# Assert
# If slope is positive memory utilization is increasing
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
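To make the slope assertion above concrete, here is a tiny self-contained illustration of the same linregress check on synthetic memory readings (the numbers are invented for demonstration):

```python
import numpy as np
from scipy.stats import linregress

# Two synthetic memory-usage trends in MB: one roughly flat, one growing ~11 MB per iteration.
flat = np.array([500.0, 501.2, 500.8, 501.5, 500.9])
leaky = np.array([500.0, 510.0, 521.0, 533.0, 544.0])

for label, trend in (("flat", flat), ("leaky", leaky)):
    slope, *_ = linregress(np.arange(len(trend)), trend)
    # The test flags a leak when the fitted slope exceeds 2 MB per iteration.
    print(f"{label}: ~{slope:.2f} MB/iteration, leak suspected: {slope >= 2}")
```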


@@ -48,10 +48,11 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
user=default_user,
)
# Act
org_files = collect_files(user=default_user)["org"]
# Act
# should not raise IsADirectoryError and return orgfile
# Assert
# should return orgfile and not raise IsADirectoryError
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
@@ -62,12 +63,14 @@ def test_text_search_setup_with_empty_file_raises_error(
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# Assert
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
verify_embeddings(0, default_user)
@@ -79,12 +82,15 @@ def test_text_indexer_deletes_embedding_before_regenerate(
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleting all embeddings for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@@ -93,13 +99,14 @@ def test_text_search_setup_batch_processes(content_config: ContentConfig, defaul
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Created 4 new embeddings" in caplog.text
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@@ -122,8 +129,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text
# Assert
assert "Deleting all embeddings for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
@@ -135,7 +142,6 @@ async def test_text_search(search_config: SearchConfig):
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
@@ -159,13 +165,12 @@ async def test_text_search(search_config: SearchConfig):
# Act
hits = await text_search.query(default_user, query)
# Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
# Assert
search_result = results[0].entry
assert "git clone" in search_result
assert "git clone" in search_result, 'search result did not contain "git clone" entry'
# ----------------------------------------------------------------------------------------------------
@@ -188,8 +193,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@@ -245,8 +251,9 @@ conda activate khoj
)
# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@@ -256,27 +263,29 @@ def test_regenerate_index_with_new_entry(
):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
initial_data = get_org_files(org_config)
# append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(org_config)
final_data = get_org_files(org_config)
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# regenerate notes jsonl, model embeddings and model to include entry from new file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
# Assert
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
verify_embeddings(11, default_user)
@@ -311,8 +320,8 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user)
@@ -327,29 +336,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
initial_data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
data = get_org_files(org_config_with_only_new_file)
final_data = get_org_files(org_config_with_only_new_file)
# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user)
@@ -379,9 +388,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
final_logs = caplog.text
# Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(11, default_user)
@@ -390,6 +398,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
text_search.setup(