Accept file deletion requests by clients during sync

- Remove unused full_corpus boolean. The full_corpus=False code path
  wasn't being used (except in a test)
- The full_corpus=True code path that was actually used ignored file
  deletion requests sent by clients during sync. Unclear why this was done
- Added unit test to prevent regression and show that file deletion
  requests by clients during sync are no longer ignored
This commit is contained in:
Debanjum Singh Solanky 2024-07-18 23:24:12 +05:30
parent 5923b6d89e
commit bba4e0b529
14 changed files with 84 additions and 80 deletions

View file

@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Docx files # Extract Entries from specified Docx files
with timer("Extract entries from specified DOCX files", logger): with timer("Extract entries from specified DOCX files", logger):

View file

@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries):
else: else:
return return
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "": if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content") logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content")

View file

@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files # Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger): with timer("Extract entries from specified Image files", logger):

View file

@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""])
deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
max_tokens = 256 max_tokens = 256
# Extract Entries from specified Markdown files # Extract Entries from specified Markdown files

View file

@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries):
self.body_params = {"page_size": 100} self.body_params = {"page_size": 100}
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
current_entries = [] current_entries = []
# Get all pages # Get all pages

View file

@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False deletion_file_names = set([file for file in files if files[file] == ""])
) -> Tuple[int, int]: files_to_process = set(files) - deletion_file_names
if not full_corpus: files = {file: files[file] for file in files_to_process}
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Org files # Extract Entries from specified Org files
max_tokens = 256 max_tokens = 256

View file

@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
if not full_corpus: deletion_file_names = set([file for file in files if files[file] == b""])
deletion_file_names = set([file for file in files if files[file] == b""]) files_to_process = set(files) - deletion_file_names
files_to_process = set(files) - deletion_file_names files = {file: files[file] for file in files_to_process}
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified Pdf files # Extract Entries from specified Pdf files
with timer("Extract entries from specified PDF files", logger): with timer("Extract entries from specified PDF files", logger):

View file

@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries):
super().__init__() super().__init__()
# Define Functions # Define Functions
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False deletion_file_names = set([file for file in files if files[file] == ""])
) -> Tuple[int, int]: files_to_process = set(files) - deletion_file_names
if not full_corpus: files = {file: files[file] for file in files_to_process}
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified plaintext files # Extract Entries from specified plaintext files
with timer("Extract entries from specified Plaintext files", logger): with timer("Extract entries from specified Plaintext files", logger):

View file

@ -31,9 +31,7 @@ class TextToEntries(ABC):
self.date_filter = DateFilter() self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process( def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
... ...
@staticmethod @staticmethod

View file

@ -1313,7 +1313,6 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]], files: Optional[dict[str, dict[str, str]]],
regenerate: bool = False, regenerate: bool = False,
t: Optional[state.SearchType] = state.SearchType.All, t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True,
user: KhojUser = None, user: KhojUser = None,
) -> bool: ) -> bool:
success = True success = True
@ -1344,7 +1343,6 @@ def configure_content(
OrgToEntries, OrgToEntries,
files.get("org"), files.get("org"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:
@ -1362,7 +1360,6 @@ def configure_content(
MarkdownToEntries, MarkdownToEntries,
files.get("markdown"), files.get("markdown"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@ -1379,7 +1376,6 @@ def configure_content(
PdfToEntries, PdfToEntries,
files.get("pdf"), files.get("pdf"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@ -1398,7 +1394,6 @@ def configure_content(
PlaintextToEntries, PlaintextToEntries,
files.get("plaintext"), files.get("plaintext"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
@ -1418,7 +1413,6 @@ def configure_content(
GithubToEntries, GithubToEntries,
None, None,
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
config=github_config, config=github_config,
) )
@ -1439,7 +1433,6 @@ def configure_content(
NotionToEntries, NotionToEntries,
None, None,
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
config=notion_config, config=notion_config,
) )
@ -1459,7 +1452,6 @@ def configure_content(
ImageToEntries, ImageToEntries,
files.get("image"), files.get("image"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:
@ -1472,7 +1464,6 @@ def configure_content(
DocxToEntries, DocxToEntries,
files.get("docx"), files.get("docx"),
regenerate=regenerate, regenerate=regenerate,
full_corpus=full_corpus,
user=user, user=user,
) )
except Exception as e: except Exception as e:

View file

@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas
notion_redirect = str(request.app.url_path_for("notion_config_page")) notion_redirect = str(request.app.url_path_for("notion_config_page"))
# Trigger an async job to configure_content. Let it run without blocking the response. # Trigger an async job to configure_content. Let it run without blocking the response.
background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user) background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user)
return RedirectResponse(notion_redirect) return RedirectResponse(notion_redirect)

View file

@ -199,17 +199,16 @@ def setup(
text_to_entries: Type[TextToEntries], text_to_entries: Type[TextToEntries],
files: dict[str, str], files: dict[str, str],
regenerate: bool, regenerate: bool,
full_corpus: bool = True,
user: KhojUser = None, user: KhojUser = None,
config=None, config=None,
) -> None: ) -> Tuple[int, int]:
if config: if config:
num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process( num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate files=files, user=user, regenerate=regenerate
) )
else: else:
num_new_embeddings, num_deleted_embeddings = text_to_entries().process( num_new_embeddings, num_deleted_embeddings = text_to_entries().process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate files=files, user=user, regenerate=regenerate
) )
if files: if files:
@ -219,6 +218,8 @@ def setup(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..." f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..."
) )
return num_new_embeddings, num_deleted_embeddings
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder""" """Score all retrieved entries using the cross-encoder"""

View file

@ -122,7 +122,7 @@ def get_org_files(config: TextContentConfig):
logger.debug("At least one of org-files or org-file-filter is required to be specified") logger.debug("At least one of org-files or org-file-filter is required to be specified")
return {} return {}
"Get Org files to process" # Get Org files to process
absolute_org_files, filtered_org_files = set(), set() absolute_org_files, filtered_org_files = set(), set()
if org_files: if org_files:
absolute_org_files = {get_absolute_path(org_file) for org_file in org_files} absolute_org_files = {get_absolute_path(org_file) for org_file in org_files}

View file

@ -6,9 +6,16 @@ from pathlib import Path
import pytest import pytest
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.fs_syncer import collect_files, get_org_files from khoj.utils.fs_syncer import collect_files, get_org_files
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
@ -151,7 +158,6 @@ async def test_text_search(search_config: SearchConfig):
OrgToEntries, OrgToEntries,
data, data,
True, True,
True,
default_user, default_user,
) )
@ -240,7 +246,6 @@ conda activate khoj
OrgToEntries, OrgToEntries,
data, data,
regenerate=False, regenerate=False,
full_corpus=False,
user=default_user, user=default_user,
) )
@ -396,6 +401,49 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
verify_embeddings(3, default_user) verify_embeddings(3, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.parametrize(
    "text_to_entries",
    [
        (OrgToEntries),
    ],
)
def test_update_index_with_deleted_file(
    org_config_with_only_new_file: LocalOrgConfig, text_to_entries: TextToEntries, default_user: KhojUser
):
    """Delete entries associated with a file when its path is synced with empty content.

    Regression test: file deletion requests by clients during sync used to be
    ignored on the (previously default) full_corpus=True code path.
    """
    # Arrange
    file_to_index = "test"
    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
    initial_data = {file_to_index: new_entry}
    # Empty content for a known file path signals a deletion request
    final_data = {file_to_index: ""}

    # Act
    # load entries after adding file
    initial_added_entries, _ = text_search.setup(text_to_entries, initial_data, regenerate=True, user=default_user)
    initial_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()

    # load entries after deleting file
    final_added_entries, final_deleted_entries = text_search.setup(
        text_to_entries, final_data, regenerate=False, user=default_user
    )
    final_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()

    # Assert
    assert initial_total_entries > 0, "File entries not indexed"
    assert initial_added_entries > 0, "No entries got added"

    assert final_total_entries == 0, "File did not get deleted"
    assert final_added_entries == 0, "Entries were unexpectedly added in delete entries pass"
    assert final_deleted_entries == initial_added_entries, "All added entries were not deleted"
    # NOTE: the original trailing `, "Embeddings still exist for user"` built a dead
    # tuple and was never an assertion message; verify_embeddings asserts internally.
    verify_embeddings(0, default_user)

    # Clean up
    EntryAdapters.delete_all_entries(default_user)
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser): def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):