diff --git a/src/khoj/processor/content/docx/docx_to_entries.py b/src/khoj/processor/content/docx/docx_to_entries.py index ab28066d..00ed3ca4 100644 --- a/src/khoj/processor/content/docx/docx_to_entries.py +++ b/src/khoj/processor/content/docx/docx_to_entries.py @@ -19,16 +19,11 @@ class DocxToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == b""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + deletion_file_names = set([file for file in files if files[file] == b""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} # Extract Entries from specified Docx files with timer("Extract entries from specified DOCX files", logger): diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 2aa63d4e..1f3dea00 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -48,9 +48,7 @@ class GithubToEntries(TextToEntries): else: return - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: if self.config.pat_token is None or self.config.pat_token == "": logger.error(f"Github PAT token is not set. 
Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content") diff --git a/src/khoj/processor/content/images/image_to_entries.py b/src/khoj/processor/content/images/image_to_entries.py index 20705a0f..d28518b7 100644 --- a/src/khoj/processor/content/images/image_to_entries.py +++ b/src/khoj/processor/content/images/image_to_entries.py @@ -20,16 +20,11 @@ class ImageToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == b""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + deletion_file_names = set([file for file in files if files[file] == b""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} # Extract Entries from specified image files with timer("Extract entries from specified Image files", logger): diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py index f18e1e21..fdb0c549 100644 --- a/src/khoj/processor/content/markdown/markdown_to_entries.py +++ b/src/khoj/processor/content/markdown/markdown_to_entries.py @@ -19,16 +19,11 @@ class MarkdownToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: # Extract 
required fields from config - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == ""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + deletion_file_names = set([file for file in files if files[file] == ""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} max_tokens = 256 # Extract Entries from specified Markdown files diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 57456ed5..c53d4020 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ b/src/khoj/processor/content/notion/notion_to_entries.py @@ -78,9 +78,7 @@ class NotionToEntries(TextToEntries): self.body_params = {"page_size": 100} - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: current_entries = [] # Get all pages diff --git a/src/khoj/processor/content/org_mode/org_to_entries.py b/src/khoj/processor/content/org_mode/org_to_entries.py index c528244d..1272da11 100644 --- a/src/khoj/processor/content/org_mode/org_to_entries.py +++ b/src/khoj/processor/content/org_mode/org_to_entries.py @@ -20,15 +20,10 @@ class OrgToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == ""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + def process(self, 
files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + deletion_file_names = set([file for file in files if files[file] == ""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} # Extract Entries from specified Org files max_tokens = 256 diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index 45ff7261..59ffc388 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -22,16 +22,11 @@ class PdfToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: # Extract required fields from config - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == b""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + deletion_file_names = set([file for file in files if files[file] == b""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} # Extract Entries from specified Pdf files with timer("Extract entries from specified PDF files", logger): diff --git a/src/khoj/processor/content/plaintext/plaintext_to_entries.py b/src/khoj/processor/content/plaintext/plaintext_to_entries.py index 2c994899..483e752f 100644 --- a/src/khoj/processor/content/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/content/plaintext/plaintext_to_entries.py @@ -20,15 +20,10 @@ class PlaintextToEntries(TextToEntries): super().__init__() # Define Functions - def process( - self, files: dict[str, str] = 
None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: - if not full_corpus: - deletion_file_names = set([file for file in files if files[file] == ""]) - files_to_process = set(files) - deletion_file_names - files = {file: files[file] for file in files_to_process} - else: - deletion_file_names = None + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: + deletion_file_names = set([file for file in files if files[file] == ""]) + files_to_process = set(files) - deletion_file_names + files = {file: files[file] for file in files_to_process} # Extract Entries from specified plaintext files with timer("Extract entries from specified Plaintext files", logger): diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index cdb2e207..6fee9c0c 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -31,9 +31,7 @@ class TextToEntries(ABC): self.date_filter = DateFilter() @abstractmethod - def process( - self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False - ) -> Tuple[int, int]: + def process(self, files: dict[str, str] = None, user: KhojUser = None, regenerate: bool = False) -> Tuple[int, int]: ... 
@staticmethod diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 9bff0dc6..10984fdd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1313,7 +1313,6 @@ def configure_content( files: Optional[dict[str, dict[str, str]]], regenerate: bool = False, t: Optional[state.SearchType] = state.SearchType.All, - full_corpus: bool = True, user: KhojUser = None, ) -> bool: success = True @@ -1344,7 +1343,6 @@ def configure_content( OrgToEntries, files.get("org"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) except Exception as e: @@ -1362,7 +1360,6 @@ def configure_content( MarkdownToEntries, files.get("markdown"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) @@ -1379,7 +1376,6 @@ def configure_content( PdfToEntries, files.get("pdf"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) @@ -1398,7 +1394,6 @@ def configure_content( PlaintextToEntries, files.get("plaintext"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) @@ -1418,7 +1413,6 @@ def configure_content( GithubToEntries, None, regenerate=regenerate, - full_corpus=full_corpus, user=user, config=github_config, ) @@ -1439,7 +1433,6 @@ def configure_content( NotionToEntries, None, regenerate=regenerate, - full_corpus=full_corpus, user=user, config=notion_config, ) @@ -1459,7 +1452,6 @@ def configure_content( ImageToEntries, files.get("image"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) except Exception as e: @@ -1472,7 +1464,6 @@ def configure_content( DocxToEntries, files.get("docx"), regenerate=regenerate, - full_corpus=full_corpus, user=user, ) except Exception as e: diff --git a/src/khoj/routers/notion.py b/src/khoj/routers/notion.py index 9f5d803f..e61b5fd7 100644 --- a/src/khoj/routers/notion.py +++ b/src/khoj/routers/notion.py @@ -80,6 +80,6 @@ async def notion_auth_callback(request: Request, background_tasks: BackgroundTas notion_redirect = 
str(request.app.url_path_for("notion_config_page")) # Trigger an async job to configure_content. Let it run without blocking the response. - background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, True, user) + background_tasks.add_task(run_in_executor, configure_content, {}, False, SearchType.Notion, user) return RedirectResponse(notion_redirect) diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index f3ce7110..93a2b724 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -199,17 +199,16 @@ def setup( text_to_entries: Type[TextToEntries], files: dict[str, str], regenerate: bool, - full_corpus: bool = True, user: KhojUser = None, config=None, -) -> None: +) -> Tuple[int, int]: if config: num_new_embeddings, num_deleted_embeddings = text_to_entries(config).process( - files=files, full_corpus=full_corpus, user=user, regenerate=regenerate + files=files, user=user, regenerate=regenerate ) else: num_new_embeddings, num_deleted_embeddings = text_to_entries().process( - files=files, full_corpus=full_corpus, user=user, regenerate=regenerate + files=files, user=user, regenerate=regenerate ) if files: @@ -219,6 +218,8 @@ def setup( f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names[:10]} ..." 
) + return num_new_embeddings, num_deleted_embeddings + def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 5a20f418..ade55f34 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -122,7 +122,7 @@ def get_org_files(config: TextContentConfig): logger.debug("At least one of org-files or org-file-filter is required to be specified") return {} - "Get Org files to process" + # Get Org files to process absolute_org_files, filtered_org_files = set(), set() if org_files: absolute_org_files = {get_absolute_path(org_file) for org_file in org_files} diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 915425bf..4529aa53 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -6,9 +6,16 @@ from pathlib import Path import pytest +from khoj.database.adapters import EntryAdapters from khoj.database.models import Entry, GithubConfig, KhojUser, LocalOrgConfig +from khoj.processor.content.docx.docx_to_entries import DocxToEntries from khoj.processor.content.github.github_to_entries import GithubToEntries +from khoj.processor.content.images.image_to_entries import ImageToEntries +from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries from khoj.processor.content.org_mode.org_to_entries import OrgToEntries +from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries +from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries +from khoj.processor.content.text_to_entries import TextToEntries from khoj.search_type import text_search from khoj.utils.fs_syncer import collect_files, get_org_files from khoj.utils.rawconfig import ContentConfig, SearchConfig @@ -151,7 +158,6 @@ async def test_text_search(search_config: SearchConfig): OrgToEntries, data, True, - True, default_user, 
) @@ -240,7 +246,6 @@ conda activate khoj OrgToEntries, data, regenerate=False, - full_corpus=False, user=default_user, )
@@ -396,6 +401,49 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
     verify_embeddings(3, default_user)
 
 
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
+@pytest.mark.parametrize(
+    "text_to_entries",
+    [
+        (OrgToEntries),
+    ],
+)
+def test_update_index_with_deleted_file(
+    org_config_with_only_new_file: LocalOrgConfig, text_to_entries: type[TextToEntries], default_user: KhojUser
+):
+    """Delete entries associated with a file when its path is passed with empty content."""
+    # Arrange
+    file_to_index = "test"
+    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+    initial_data = {file_to_index: new_entry}
+    final_data = {file_to_index: ""}
+
+    # Act
+    # Index the file with its initial content, regenerating the index from scratch
+    initial_added_entries, _ = text_search.setup(text_to_entries, initial_data, regenerate=True, user=default_user)
+    initial_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
+
+    # Re-index the same file path with empty content, which marks the file for deletion
+    final_added_entries, final_deleted_entries = text_search.setup(
+        text_to_entries, final_data, regenerate=False, user=default_user
+    )
+    final_total_entries = EntryAdapters.get_existing_entry_hashes_by_file(default_user, file_to_index).count()
+
+    # Assert
+    assert initial_total_entries > 0, "File entries not indexed"
+    assert initial_added_entries > 0, "No entries got added"
+
+    assert final_total_entries == 0, "File did not get deleted"
+    assert final_added_entries == 0, "Entries were unexpectedly added in delete entries pass"
+    assert final_deleted_entries == initial_added_entries, "All added entries were not deleted"
+
+    verify_embeddings(0, default_user)  # No embeddings should remain for user
+
+    # Clean up
+    EntryAdapters.delete_all_entries(default_user)
+
+
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
 def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):