From 9ab327a2b6b30891896786c6f53408479dfa243c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 6 Nov 2023 23:49:08 -0800 Subject: [PATCH] Store the data source of each entry in database This will be useful for updating, deleting entries by their data source. Data source can be one of Computer, Github or Notion for now Store each file/entries source in database --- src/database/adapters/__init__.py | 22 ++++++++++++++++--- .../migrations/0012_entry_file_source.py | 21 ++++++++++++++++++ src/database/models/__init__.py | 6 +++++ .../processor/github/github_to_entries.py | 7 +++++- .../processor/markdown/markdown_to_entries.py | 1 + .../processor/notion/notion_to_entries.py | 7 +++++- src/khoj/processor/org_mode/org_to_entries.py | 1 + src/khoj/processor/pdf/pdf_to_entries.py | 1 + .../plaintext/plaintext_to_entries.py | 1 + src/khoj/processor/text_to_entries.py | 4 +++- src/khoj/search_type/text_search.py | 9 ++++---- tests/test_text_search.py | 2 +- 12 files changed, 71 insertions(+), 11 deletions(-) create mode 100644 src/database/migrations/0012_entry_file_source.py diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index fa37aa99..69a3c1f4 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -287,13 +287,21 @@ class EntryAdapters: return deleted_count @staticmethod - def delete_all_entries(user: KhojUser, file_type: str = None): + def delete_all_entries_by_type(user: KhojUser, file_type: str = None): if file_type is None: deleted_count, _ = Entry.objects.filter(user=user).delete() else: deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() return deleted_count + @staticmethod + def delete_all_entries_by_source(user: KhojUser, file_source: str = None): + if file_source is None: + deleted_count, _ = Entry.objects.filter(user=user).delete() + else: + deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete() + return deleted_count + @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @@ -318,8 +326,12 @@ class EntryAdapters: return await Entry.objects.filter(user=user, file_path=file_path).adelete() @staticmethod - def aget_all_filenames(user: KhojUser): - return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True) + def aget_all_filenames_by_source(user: KhojUser, file_source: str): + return ( + Entry.objects.filter(user=user, file_source=file_source) + .distinct("file_path") + .values_list("file_path", flat=True) + ) @staticmethod async def adelete_all_entries(user: KhojUser): @@ -384,3 +396,7 @@ class EntryAdapters: @staticmethod def get_unique_file_types(user: KhojUser): return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() + + @staticmethod + def get_unique_file_source(user: KhojUser): + return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct() diff --git a/src/database/migrations/0012_entry_file_source.py b/src/database/migrations/0012_entry_file_source.py new file mode 100644 index 00000000..187136ae --- /dev/null +++ b/src/database/migrations/0012_entry_file_source.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.5 on 2023-11-07 07:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0011_merge_20231102_0138"), + ] + + operations = [ + migrations.AddField( + model_name="entry", + name="file_source", + field=models.CharField( + choices=[("computer", "Computer"), ("notion", "Notion"), ("github", "Github")], + default="computer", + max_length=30, + ), + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 5dd9622b..b1be9ded 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -131,11 +131,17 @@ class Entry(BaseModel): GITHUB = "github" CONVERSATION = "conversation" + class EntrySource(models.TextChoices): + COMPUTER = "computer" + NOTION = "notion" + GITHUB = "github" + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) embeddings = VectorField(dimensions=384) raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) + file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER) file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT) file_path = models.CharField(max_length=400, default=None, null=True, blank=True) file_name = models.CharField(max_length=400, default=None, null=True, blank=True) diff --git a/src/khoj/processor/github/github_to_entries.py b/src/khoj/processor/github/github_to_entries.py index 14e9b696..56279453 100644 --- a/src/khoj/processor/github/github_to_entries.py +++ b/src/khoj/processor/github/github_to_entries.py @@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.GITHUB, + DbEntry.EntrySource.GITHUB, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/markdown/markdown_to_entries.py b/src/khoj/processor/markdown/markdown_to_entries.py index e0b76368..0dd71740 100644 --- a/src/khoj/processor/markdown/markdown_to_entries.py +++ b/src/khoj/processor/markdown/markdown_to_entries.py @@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.MARKDOWN, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/notion/notion_to_entries.py b/src/khoj/processor/notion/notion_to_entries.py index a4b15d4e..7a88e2a1 100644 --- a/src/khoj/processor/notion/notion_to_entries.py +++ b/src/khoj/processor/notion/notion_to_entries.py @@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user + current_entries, + DbEntry.EntryType.NOTION, + DbEntry.EntrySource.NOTION, + key="compiled", + logger=logger, + user=user, ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/org_mode/org_to_entries.py b/src/khoj/processor/org_mode/org_to_entries.py index bf6df6dc..04ce97e4 100644 --- a/src/khoj/processor/org_mode/org_to_entries.py +++ b/src/khoj/processor/org_mode/org_to_entries.py @@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.ORG, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/pdf/pdf_to_entries.py b/src/khoj/processor/pdf/pdf_to_entries.py index 81c2250f..3a47096a 100644 --- a/src/khoj/processor/pdf/pdf_to_entries.py +++ b/src/khoj/processor/pdf/pdf_to_entries.py @@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PDF, + DbEntry.EntrySource.COMPUTER, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/plaintext/plaintext_to_entries.py b/src/khoj/processor/plaintext/plaintext_to_entries.py index fd5e1de2..d42dae30 100644 --- a/src/khoj/processor/plaintext/plaintext_to_entries.py +++ b/src/khoj/processor/plaintext/plaintext_to_entries.py @@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, DbEntry.EntryType.PLAINTEXT, + DbEntry.EntrySource.COMPUTER, key="compiled", logger=logger, deletion_filenames=deletion_file_names, diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 4661fd9b..3d79e02e 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -78,6 +78,7 @@ class TextToEntries(ABC): self, current_entries: List[Entry], file_type: str, + file_source: str, key="compiled", logger: logging.Logger = None, deletion_filenames: Set[str] = None, @@ -95,7 +96,7 @@ class TextToEntries(ABC): if regenerate: with timer("Cleared existing dataset for regeneration in", logger): logger.debug(f"Deleting all entries for file type {file_type}") - num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type) + num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type) hashes_to_process = set() with timer("Identified entries to add to database in", logger): @@ -132,6 +133,7 @@ class TextToEntries(ABC): compiled=entry.compiled, heading=entry.heading[:1000], # Truncate to max chars of field allowed file_path=entry.file, + file_source=file_source, file_type=file_type, hashed_value=entry_hash, corpus_id=entry.corpus_id, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 14f5b770..ba2fc9ec 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -204,11 +204,12 @@ def setup( files=files, full_corpus=full_corpus, user=user, regenerate=regenerate ) - file_names = [file_name for file_name in files] + if files: + file_names = [file_name for file_name in files] - logger.info( - f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" - ) + logger.info( + f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}" + ) def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 7d8c30fb..3d729ab5 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -def test_text_search_setup_with_empty_file_raises_error( +def test_text_search_setup_with_empty_file_creates_no_entries( org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog ): # Arrange