Store the data source of each entry in database

This will be useful for updating and deleting entries by their data
source. The data source can be one of Computer, GitHub, or Notion for now

Store the source of each file/entry in the database
This commit is contained in:
Debanjum Singh Solanky 2023-11-06 23:49:08 -08:00
parent c82cd0862a
commit 9ab327a2b6
12 changed files with 71 additions and 11 deletions

View file

@ -287,13 +287,21 @@ class EntryAdapters:
return deleted_count
@staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None):
def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
if file_type is None:
deleted_count, _ = Entry.objects.filter(user=user).delete()
else:
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
return deleted_count
@staticmethod
def delete_all_entries_by_source(user: KhojUser, file_source: str = None):
    """Delete a user's entries, optionally limited to a single data source.

    When file_source is None, every entry belonging to the user is removed.
    Returns the number of database rows deleted.
    """
    filters = {"user": user}
    if file_source is not None:
        filters["file_source"] = file_source
    deleted_count, _ = Entry.objects.filter(**filters).delete()
    return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
    """Return the hashed values of all entries indexed from file_path for this user."""
    entries_for_file = Entry.objects.filter(user=user, file_path=file_path)
    return entries_for_file.values_list("hashed_value", flat=True)
@ -318,8 +326,12 @@ class EntryAdapters:
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
@staticmethod
def aget_all_filenames(user: KhojUser):
return Entry.objects.filter(user=user).distinct("file_path").values_list("file_path", flat=True)
def aget_all_filenames_by_source(user: KhojUser, file_source: str):
return (
Entry.objects.filter(user=user, file_source=file_source)
.distinct("file_path")
.values_list("file_path", flat=True)
)
@staticmethod
async def adelete_all_entries(user: KhojUser):
@ -384,3 +396,7 @@ class EntryAdapters:
@staticmethod
def get_unique_file_types(user: KhojUser):
    """Return the distinct file types across all of the user's entries."""
    file_types = Entry.objects.filter(user=user).values_list("file_type", flat=True)
    return file_types.distinct()
@staticmethod
def get_unique_file_source(user: KhojUser):
    """Return the distinct data sources across all of the user's entries."""
    # NOTE(review): consider pluralizing to get_unique_file_sources for
    # consistency with get_unique_file_types — would need a coordinated rename at call sites.
    file_sources = Entry.objects.filter(user=user).values_list("file_source", flat=True)
    return file_sources.distinct()

View file

@ -0,0 +1,21 @@
# Generated by Django 4.2.5 on 2023-11-07 07:24
from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the file_source column to Entry; existing rows default to "computer"."""

    dependencies = [
        ("database", "0011_merge_20231102_0138"),
    ]

    operations = [
        migrations.AddField(
            model_name="entry",
            name="file_source",
            field=models.CharField(
                max_length=30,
                choices=[("computer", "Computer"), ("notion", "Notion"), ("github", "Github")],
                default="computer",
            ),
        ),
    ]

View file

@ -131,11 +131,17 @@ class Entry(BaseModel):
GITHUB = "github"
CONVERSATION = "conversation"
class EntrySource(models.TextChoices):
    # Which connector an indexed entry came from: synced local files,
    # a Notion workspace, or GitHub repositories.
    COMPUTER = "computer"
    NOTION = "notion"
    GITHUB = "github"
# Owning user; nullable/blank so entries can exist without an attached user
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
# 384-dimension vector embedding of the compiled entry text
embeddings = VectorField(dimensions=384)
raw = models.TextField()  # original entry text as extracted from the source
compiled = models.TextField()  # processed text used for embedding/search
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)  # optional section heading
# Data source that produced this entry (computer/notion/github); defaults to local files
file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER)
# File format of the source document (org, markdown, pdf, ...)
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)

View file

@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
current_entries,
DbEntry.EntryType.GITHUB,
DbEntry.EntrySource.GITHUB,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.MARKDOWN,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
current_entries,
DbEntry.EntryType.NOTION,
DbEntry.EntrySource.NOTION,
key="compiled",
logger=logger,
user=user,
)
return num_new_embeddings, num_deleted_embeddings

View file

@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.ORG,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PDF,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,

View file

@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.PLAINTEXT,
DbEntry.EntrySource.COMPUTER,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,

View file

@ -78,6 +78,7 @@ class TextToEntries(ABC):
self,
current_entries: List[Entry],
file_type: str,
file_source: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
@ -95,7 +96,7 @@ class TextToEntries(ABC):
if regenerate:
with timer("Cleared existing dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
@ -132,6 +133,7 @@ class TextToEntries(ABC):
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_source=file_source,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,

View file

@ -204,11 +204,12 @@ def setup(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
file_names = [file_name for file_name in files]
if files:
file_names = [file_name for file_name in files]
logger.info(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)
logger.info(
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:

View file

@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error(
def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange