mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Store the data source of each entry in database
This will be useful for updating and deleting entries by their data source. The data source can be one of Computer, Github or Notion for now. Store the source of each file/entry in the database.
This commit is contained in:
parent
c82cd0862a
commit
9ab327a2b6
12 changed files with 71 additions and 11 deletions
|
@ -287,13 +287,21 @@ class EntryAdapters:
|
|||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_entries(user: KhojUser, file_type: str = None):
|
||||
def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
    """Delete a user's entries, optionally limited to a single file type.

    When ``file_type`` is None, every entry owned by ``user`` is deleted.

    Returns the first element of the tuple produced by ``QuerySet.delete()``.
    """
    # Start from all of the user's entries and narrow only when a type is given.
    entries_to_delete = Entry.objects.filter(user=user)
    if file_type is not None:
        entries_to_delete = entries_to_delete.filter(file_type=file_type)

    deleted_count, _ = entries_to_delete.delete()
    return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_entries_by_source(user: KhojUser, file_source: str = None):
    """Delete a user's entries, optionally limited to a single data source.

    When ``file_source`` is None, every entry owned by ``user`` is deleted.

    Returns the first element of the tuple produced by ``QuerySet.delete()``.
    """
    # Start from all of the user's entries and narrow only when a source is given.
    entries_to_delete = Entry.objects.filter(user=user)
    if file_source is not None:
        entries_to_delete = entries_to_delete.filter(file_source=file_source)

    deleted_count, _ = entries_to_delete.delete()
    return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
    """Return the ``hashed_value`` of each of the user's entries stored for ``file_path``."""
    entries_for_file = Entry.objects.filter(user=user, file_path=file_path)
    # flat=True yields bare hash values rather than one-element tuples.
    return entries_for_file.values_list("hashed_value", flat=True)
|
||||
|
@ -318,8 +326,12 @@ class EntryAdapters:
|
|||
return await Entry.objects.filter(user=user, file_path=file_path).adelete()
|
||||
|
||||
@staticmethod
|
||||
def aget_all_filenames(user: KhojUser):
    """Return the distinct ``file_path`` values across all of the user's entries.

    Note: this is a plain ``def``; it returns the queryset without awaiting anything.
    """
    user_entries = Entry.objects.filter(user=user)
    one_entry_per_path = user_entries.distinct("file_path")
    return one_entry_per_path.values_list("file_path", flat=True)
|
||||
def aget_all_filenames_by_source(user: KhojUser, file_source: str):
    """Return the distinct ``file_path`` values of the user's entries from ``file_source``.

    Note: this is a plain ``def``; it returns the queryset without awaiting anything.
    """
    source_entries = Entry.objects.filter(user=user, file_source=file_source)
    one_entry_per_path = source_entries.distinct("file_path")
    return one_entry_per_path.values_list("file_path", flat=True)
|
||||
|
||||
@staticmethod
|
||||
async def adelete_all_entries(user: KhojUser):
|
||||
|
@ -384,3 +396,7 @@ class EntryAdapters:
|
|||
@staticmethod
|
||||
def get_unique_file_types(user: KhojUser):
    """Return the distinct ``file_type`` values present among the user's entries."""
    user_entries = Entry.objects.filter(user=user)
    file_types = user_entries.values_list("file_type", flat=True)
    return file_types.distinct()
|
||||
|
||||
@staticmethod
|
||||
def get_unique_file_source(user: KhojUser):
    """Return the distinct ``file_source`` values present among the user's entries."""
    user_entries = Entry.objects.filter(user=user)
    file_sources = user_entries.values_list("file_source", flat=True)
    return file_sources.distinct()
|
||||
|
|
21
src/database/migrations/0012_entry_file_source.py
Normal file
21
src/database/migrations/0012_entry_file_source.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-07 07:24
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add the ``file_source`` field to the ``entry`` model."""

    # Applied after the preceding merge migration of the ``database`` app.
    dependencies = [("database", "0011_merge_20231102_0138")]

    operations = [
        migrations.AddField(
            model_name="entry",
            name="file_source",
            field=models.CharField(
                max_length=30,
                # Rows that do not specify a source are stored as "computer".
                default="computer",
                choices=[
                    ("computer", "Computer"),
                    ("notion", "Notion"),
                    ("github", "Github"),
                ],
            ),
        ),
    ]
|
|
@ -131,11 +131,17 @@ class Entry(BaseModel):
|
|||
GITHUB = "github"
|
||||
CONVERSATION = "conversation"
|
||||
|
||||
class EntrySource(models.TextChoices):
    """Data sources an entry can come from."""

    # Labels are written out explicitly; they equal the ones Django derives
    # from the member names (and the choices recorded in the migration).
    COMPUTER = "computer", "Computer"
    NOTION = "notion", "Notion"
    GITHUB = "github", "Github"
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
embeddings = VectorField(dimensions=384)
|
||||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||
file_source = models.CharField(max_length=30, choices=EntrySource.choices, default=EntrySource.COMPUTER)
|
||||
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
|
||||
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
|
|
|
@ -104,7 +104,12 @@ class GithubToEntries(TextToEntries):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
|
||||
current_entries,
|
||||
DbEntry.EntryType.GITHUB,
|
||||
DbEntry.EntrySource.GITHUB,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -47,6 +47,7 @@ class MarkdownToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.MARKDOWN,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -250,7 +250,12 @@ class NotionToEntries(TextToEntries):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
|
||||
current_entries,
|
||||
DbEntry.EntryType.NOTION,
|
||||
DbEntry.EntrySource.NOTION,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -48,6 +48,7 @@ class OrgToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.ORG,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -46,6 +46,7 @@ class PdfToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.PDF,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -56,6 +56,7 @@ class PlaintextToEntries(TextToEntries):
|
|||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
DbEntry.EntryType.PLAINTEXT,
|
||||
DbEntry.EntrySource.COMPUTER,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
deletion_filenames=deletion_file_names,
|
||||
|
|
|
@ -78,6 +78,7 @@ class TextToEntries(ABC):
|
|||
self,
|
||||
current_entries: List[Entry],
|
||||
file_type: str,
|
||||
file_source: str,
|
||||
key="compiled",
|
||||
logger: logging.Logger = None,
|
||||
deletion_filenames: Set[str] = None,
|
||||
|
@ -95,7 +96,7 @@ class TextToEntries(ABC):
|
|||
if regenerate:
|
||||
with timer("Cleared existing dataset for regeneration in", logger):
|
||||
logger.debug(f"Deleting all entries for file type {file_type}")
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
|
||||
|
||||
hashes_to_process = set()
|
||||
with timer("Identified entries to add to database in", logger):
|
||||
|
@ -132,6 +133,7 @@ class TextToEntries(ABC):
|
|||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_source=file_source,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
|
|
|
@ -204,11 +204,12 @@ def setup(
|
|||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
)
|
||||
|
||||
file_names = [file_name for file_name in files]
|
||||
if files:
|
||||
file_names = [file_name for file_name in files]
|
||||
|
||||
logger.info(
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
logger.info(
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
|
||||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
def test_text_search_setup_with_empty_file_creates_no_entries(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
|
|
Loading…
Reference in a new issue