Rename DB models Embeddings, EmbeddingsAdapters to Entry, EntryAdapters

Improves readability as the names more closely match the underlying
constructs

- Entry is any atomic item indexed by Khoj. This can be an org-mode
  entry, a markdown section, a PDF or Notion page, etc.

- Embeddings are semantic vectors generated by the search ML model
  that encode the meaning contained in an entry's text.

- An "Entry" contains an "Embeddings" vector as well as other metadata
  about the entry, like its filename (see the sketch below).
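
For orientation, a minimal sketch of the renamed model, assuming Django with
pgvector (which would provide the CosineDistance annotation used in the adapters).
It is condensed from the database.models hunk in this commit: field options are
abbreviated, the project's BaseModel and the KhojUser foreign key are omitted, and
the VectorField is an assumption inferred from the embeddings=embedding and
CosineDistance("embeddings", ...) usages below.

import uuid

from django.db import models
from pgvector.django import VectorField  # assumption: pgvector backs the cosine-distance search


class Entry(models.Model):  # the real model subclasses the project's BaseModel
    """One atomic indexed item: an org-mode entry, markdown section, PDF page, etc."""

    class EntryType(models.TextChoices):
        IMAGE = "image"
        PDF = "pdf"
        PLAINTEXT = "plaintext"
        # ... plus org, markdown, github, notion, as referenced via EntryType.* in the processors

    embeddings = VectorField()  # semantic vector encoding the entry's text
    raw = models.TextField()  # original entry text
    compiled = models.TextField()  # text that was actually encoded by the search model
    heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
    file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
    file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
    file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
    corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)

EntryAdapters then works purely with these rows, for example annotating them with
CosineDistance("embeddings", query_embedding) and ordering by distance, as in the
database.adapters hunk below.
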
author Debanjum Singh Solanky 2023-10-31 18:50:54 -07:00
parent 54a387326c
commit bcbee05a9e
15 changed files with 115 additions and 87 deletions


@@ -27,7 +27,7 @@ from database.models import (
KhojApiUser,
NotionConfig,
GithubConfig,
Embeddings,
Entry,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
@@ -286,54 +286,54 @@ class ConversationAdapters:
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters:
class EntryAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
date_filter = DateFilter()
@staticmethod
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
def delete_embedding_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
def delete_entry_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
def delete_all_embeddings(user: KhojUser, file_type: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
def delete_all_entries(user: KhojUser, file_type: str):
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
return embeddings.filter(
embeddingsdates__date__gte=start_date,
embeddingsdates__date__lte=end_date,
def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
return entry.filter(
entrydates__date__gte=start_date,
entrydates__date__lte=end_date,
)
@staticmethod
async def user_has_embeddings(user: KhojUser):
return await Embeddings.objects.filter(user=user).aexists()
async def user_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists()
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Embeddings.objects.filter(user=user)
return Entry.objects.filter(user=user)
for term in explicit_word_terms:
if term.startswith("+"):
@@ -354,32 +354,32 @@ class EmbeddingsAdapters:
if min_date is not None:
# Convert the min_date timestamp to yyyy-mm-dd format
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date)
if max_date is not None:
# Convert the max_date timestamp to yyyy-mm-dd format
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date)
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
relevant_entries = Entry.objects.filter(user=user).filter(
q_filter_terms,
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
return relevant_embeddings
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
return relevant_entries
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
):
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
relevant_embeddings = relevant_embeddings.order_by("distance")
return relevant_embeddings[:max_results]
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance")
return relevant_entries[:max_results]
@staticmethod
def get_unique_file_types(user: KhojUser):
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()


@@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-10-26 23:52
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0009_khojapiuser"),
]
operations = [
migrations.RenameModel(
old_name="Embeddings",
new_name="Entry",
),
migrations.RenameModel(
old_name="EmbeddingsDates",
new_name="EntryDates",
),
migrations.RenameField(
model_name="entrydates",
old_name="embeddings",
new_name="entry",
),
migrations.RenameIndex(
model_name="entrydates",
new_name="database_en_date_8d823c_idx",
old_name="database_em_date_a1ba47_idx",
),
]


@@ -114,8 +114,8 @@ class Conversation(BaseModel):
conversation_log = models.JSONField(default=dict)
class Embeddings(BaseModel):
class EmbeddingsType(models.TextChoices):
class Entry(BaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
PLAINTEXT = "plaintext"
@@ -130,7 +130,7 @@ class Embeddings(BaseModel):
raw = models.TextField()
compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
url = models.URLField(max_length=400, default=None, null=True, blank=True)
@@ -138,9 +138,9 @@ class Embeddings(BaseModel):
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
class EmbeddingsDates(BaseModel):
class EntryDates(BaseModel):
date = models.DateField()
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
class Meta:
indexes = [


@@ -13,8 +13,7 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, GithubConfig, KhojUser
from database.models import Entry as DbEntry, GithubConfig, KhojUser
logger = logging.getLogger(__name__)
@@ -103,7 +102,7 @@ class GithubToJsonl(TextEmbeddings):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
)
return num_new_embeddings, num_deleted_embeddings


@@ -10,7 +10,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__)
@@ -46,7 +46,7 @@ class MarkdownToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.MARKDOWN,
DbEntry.EntryType.MARKDOWN,
"compiled",
logger,
deletion_file_names,


@@ -10,7 +10,7 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser, NotionConfig
from database.models import Entry as DbEntry, KhojUser, NotionConfig
from enum import Enum
@@ -250,7 +250,7 @@ class NotionToJsonl(TextEmbeddings):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
)
return num_new_embeddings, num_deleted_embeddings


@@ -9,7 +9,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
from khoj.utils import state
from database.models import Embeddings, KhojUser
from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class OrgToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.ORG,
DbEntry.EntryType.ORG,
"compiled",
logger,
deletion_file_names,


@@ -11,7 +11,7 @@ from langchain.document_loaders import PyMuPDFLoader
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__)
@@ -45,7 +45,7 @@ class PdfToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.PDF,
DbEntry.EntryType.PDF,
"compiled",
logger,
deletion_file_names,


@@ -9,7 +9,7 @@ from bs4 import BeautifulSoup
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ class PlaintextToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.PLAINTEXT,
DbEntry.EntryType.PLAINTEXT,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,


@@ -12,8 +12,8 @@ from khoj.utils.helpers import timer, batcher
from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Embeddings, EmbeddingsDates
from database.adapters import EmbeddingsAdapters
from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters
logger = logging.getLogger(__name__)
@@ -94,14 +94,14 @@ class TextEmbeddings(ABC):
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = Embeddings.objects.filter(
existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
@@ -124,7 +124,7 @@ class TextEmbeddings(ABC):
for entry_hash, embedding in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
Embeddings(
DbEntry(
user=user,
embeddings=embedding,
raw=entry.raw,
@@ -136,7 +136,7 @@ class TextEmbeddings(ABC):
corpus_id=entry.corpus_id,
)
)
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
new_embeddings = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_embeddings)} new embeddings")
num_new_embeddings += len(new_embeddings)
@@ -146,26 +146,26 @@ class TextEmbeddings(ABC):
dates = self.date_filter.extract_dates(embedding.raw)
for date in dates:
dates_to_create.append(
EmbeddingsDates(
EntryDates(
date=date,
embeddings=embedding,
)
)
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file:
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count
return num_new_embeddings, num_deleted_embeddings


@@ -48,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
from fastapi.requests import Request
from database import adapters
from database.adapters import EmbeddingsAdapters, ConversationAdapters
from database.adapters import EntryAdapters, ConversationAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
@@ -129,7 +129,7 @@ if not state.demo:
@requires(["authenticated"])
def get_config_data(request: Request):
user = request.user.object
EmbeddingsAdapters.get_unique_file_types(user)
EntryAdapters.get_unique_file_types(user)
return state.config
@@ -145,7 +145,7 @@ if not state.demo:
configuration_update_metadata = {}
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
if state.config.content_type is not None:
configuration_update_metadata["github"] = "github" in enabled_content
@@ -241,9 +241,9 @@ if not state.demo:
raise ValueError(f"Invalid content type: {content_type}")
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
@@ -372,7 +372,7 @@ def get_config_types(
):
user = request.user.object
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
enabled_file_types = EntryAdapters.get_unique_file_types(user)
configured_content_types = list(enabled_file_types)
@@ -706,7 +706,7 @@ async def extract_references_and_questions(
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
if not await EntryAdapters.user_has_entries(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)


@@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
@@ -84,7 +84,7 @@ if not state.demo:
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
user = request.user.object
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
search_type=None,


@@ -6,31 +6,31 @@ from typing import List, Tuple, Type, Union, Dict
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from sentence_transformers import util
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters
from database.models import KhojUser, Embeddings
from database.adapters import EntryAdapters
from database.models import KhojUser, Entry as DbEntry
logger = logging.getLogger(__name__)
search_type_to_embeddings_type = {
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
SearchType.Org.value: DbEntry.EntryType.ORG,
SearchType.Markdown.value: DbEntry.EntryType.MARKDOWN,
SearchType.Plaintext.value: DbEntry.EntryType.PLAINTEXT,
SearchType.Pdf.value: DbEntry.EntryType.PDF,
SearchType.Github.value: DbEntry.EntryType.GITHUB,
SearchType.Notion.value: DbEntry.EntryType.NOTION,
SearchType.All.value: None,
}
@@ -121,7 +121,7 @@ async def query(
# Find relevant entries for the query
top_k = 10
with timer("Search Time", logger, state.device):
hits = EmbeddingsAdapters.search_with_embeddings(
hits = EntryAdapters.search_with_embeddings(
user=user,
embeddings=question_embedding,
max_results=top_k,


@@ -17,7 +17,7 @@ from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from database.models import KhojUser
from database.adapters import EmbeddingsAdapters
from database.adapters import EntryAdapters
# Test
@@ -178,7 +178,7 @@ def test_get_configured_types_via_api(client, sample_org_data):
# Act
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert
assert list(enabled_types) == ["org"]


@@ -1,6 +1,5 @@
# System Packages
import logging
import locale
from pathlib import Path
import os
import asyncio
@@ -14,7 +13,7 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.fs_syncer import collect_files, get_org_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
logger = logging.getLogger(__name__)
@@ -402,10 +401,10 @@ def test_text_search_setup_github(content_config: ContentConfig, default_user: K
)
# Assert
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
assert embeddings > 1
def verify_embeddings(expected_count, user):
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
embeddings = Entry.objects.filter(user=user, file_type="org").count()
assert embeddings == expected_count