Rename DbModels Embeddings, EmbeddingsAdapter to Entry, EntryAdapter

Improves readability, as the new names more closely match the underlying
constructs

- Entry is any atomic item indexed by Khoj. This can be an org-mode
  entry, a markdown section, a PDF or Notion page etc.

- Embeddings are semantic vectors generated by the search ML model
  that encode the meaning contained in an entry's text.

- An "Entry" contains "Embeddings" vectors but also other metadata
  about the entry like filename etc.
This commit is contained in:
Debanjum Singh Solanky 2023-10-31 18:50:54 -07:00
parent 54a387326c
commit bcbee05a9e
15 changed files with 115 additions and 87 deletions

View file

@ -27,7 +27,7 @@ from database.models import (
KhojApiUser, KhojApiUser,
NotionConfig, NotionConfig,
GithubConfig, GithubConfig,
Embeddings, Entry,
GithubRepoConfig, GithubRepoConfig,
Conversation, Conversation,
ConversationProcessorConfig, ConversationProcessorConfig,
@ -286,54 +286,54 @@ class ConversationAdapters:
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst() return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters: class EntryAdapters:
word_filer = WordFilter() word_filer = WordFilter()
file_filter = FileFilter() file_filter = FileFilter()
date_filter = DateFilter() date_filter = DateFilter()
@staticmethod @staticmethod
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool: def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists() return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod @staticmethod
def delete_embedding_by_file(user: KhojUser, file_path: str): def delete_entry_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete() deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
return deleted_count return deleted_count
@staticmethod @staticmethod
def delete_all_embeddings(user: KhojUser, file_type: str): def delete_all_entries(user: KhojUser, file_type: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete() deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
return deleted_count return deleted_count
@staticmethod @staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod @staticmethod
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]): def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete() Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod @staticmethod
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date): def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
return embeddings.filter( return entry.filter(
embeddingsdates__date__gte=start_date, entrydates__date__gte=start_date,
embeddingsdates__date__lte=end_date, entrydates__date__lte=end_date,
) )
@staticmethod @staticmethod
async def user_has_embeddings(user: KhojUser): async def user_has_entries(user: KhojUser):
return await Embeddings.objects.filter(user=user).aexists() return await Entry.objects.filter(user=user).aexists()
@staticmethod @staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q() q_filter_terms = Q()
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query) explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query) file_filters = EntryAdapters.file_filter.get_filter_terms(query)
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query) date_filters = EntryAdapters.date_filter.get_query_date_range(query)
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0: if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Embeddings.objects.filter(user=user) return Entry.objects.filter(user=user)
for term in explicit_word_terms: for term in explicit_word_terms:
if term.startswith("+"): if term.startswith("+"):
@ -354,32 +354,32 @@ class EmbeddingsAdapters:
if min_date is not None: if min_date is not None:
# Convert the min_date timestamp to yyyy-mm-dd format # Convert the min_date timestamp to yyyy-mm-dd format
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d") formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date) q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date)
if max_date is not None: if max_date is not None:
# Convert the max_date timestamp to yyyy-mm-dd format # Convert the max_date timestamp to yyyy-mm-dd format
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d") formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date) q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date)
relevant_embeddings = Embeddings.objects.filter(user=user).filter( relevant_entries = Entry.objects.filter(user=user).filter(
q_filter_terms, q_filter_terms,
) )
if file_type_filter: if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) relevant_entries = relevant_entries.filter(file_type=file_type_filter)
return relevant_embeddings return relevant_entries
@staticmethod @staticmethod
def search_with_embeddings( def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
): ):
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter) relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_embeddings = relevant_embeddings.filter(user=user).annotate( relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings) distance=CosineDistance("embeddings", embeddings)
) )
if file_type_filter: if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_embeddings = relevant_embeddings.order_by("distance") relevant_entries = relevant_entries.order_by("distance")
return relevant_embeddings[:max_results] return relevant_entries[:max_results]
@staticmethod @staticmethod
def get_unique_file_types(user: KhojUser): def get_unique_file_types(user: KhojUser):
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct() return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()

View file

@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-10-26 23:52
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0009_khojapiuser"),
]
operations = [
migrations.RenameModel(
old_name="Embeddings",
new_name="Entry",
),
migrations.RenameModel(
old_name="EmbeddingsDates",
new_name="EntryDates",
),
migrations.RenameField(
model_name="entrydates",
old_name="embeddings",
new_name="entry",
),
migrations.RenameIndex(
model_name="entrydates",
new_name="database_en_date_8d823c_idx",
old_name="database_em_date_a1ba47_idx",
),
]

View file

@ -114,8 +114,8 @@ class Conversation(BaseModel):
conversation_log = models.JSONField(default=dict) conversation_log = models.JSONField(default=dict)
class Embeddings(BaseModel): class Entry(BaseModel):
class EmbeddingsType(models.TextChoices): class EntryType(models.TextChoices):
IMAGE = "image" IMAGE = "image"
PDF = "pdf" PDF = "pdf"
PLAINTEXT = "plaintext" PLAINTEXT = "plaintext"
@ -130,7 +130,7 @@ class Embeddings(BaseModel):
raw = models.TextField() raw = models.TextField()
compiled = models.TextField() compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True) heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT) file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True) file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True) file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
url = models.URLField(max_length=400, default=None, null=True, blank=True) url = models.URLField(max_length=400, default=None, null=True, blank=True)
@ -138,9 +138,9 @@ class Embeddings(BaseModel):
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False) corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
class EmbeddingsDates(BaseModel): class EntryDates(BaseModel):
date = models.DateField() date = models.DateField()
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates") entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
class Meta: class Meta:
indexes = [ indexes = [

View file

@ -13,8 +13,7 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry from database.models import Entry as DbEntry, GithubConfig, KhojUser
from database.models import Embeddings, GithubConfig, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -103,7 +102,7 @@ class GithubToJsonl(TextEmbeddings):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -10,7 +10,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences from khoj.utils.constants import empty_escape_sequences
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,7 +46,7 @@ class MarkdownToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, current_entries,
Embeddings.EmbeddingsType.MARKDOWN, DbEntry.EntryType.MARKDOWN,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,

View file

@ -10,7 +10,7 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser, NotionConfig from database.models import Entry as DbEntry, KhojUser, NotionConfig
from enum import Enum from enum import Enum
@ -250,7 +250,7 @@ class NotionToJsonl(TextEmbeddings):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
) )
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -9,7 +9,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.utils import state from khoj.utils import state
from database.models import Embeddings, KhojUser from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ class OrgToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, current_entries,
Embeddings.EmbeddingsType.ORG, DbEntry.EntryType.ORG,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,

View file

@ -11,7 +11,7 @@ from langchain.document_loaders import PyMuPDFLoader
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,7 +45,7 @@ class PdfToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, current_entries,
Embeddings.EmbeddingsType.PDF, DbEntry.EntryType.PDF,
"compiled", "compiled",
logger, logger,
deletion_file_names, deletion_file_names,

View file

@ -9,7 +9,7 @@ from bs4 import BeautifulSoup
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,7 +55,7 @@ class PlaintextToJsonl(TextEmbeddings):
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, current_entries,
Embeddings.EmbeddingsType.PLAINTEXT, DbEntry.EntryType.PLAINTEXT,
key="compiled", key="compiled",
logger=logger, logger=logger,
deletion_filenames=deletion_file_names, deletion_filenames=deletion_file_names,

View file

@ -12,8 +12,8 @@ from khoj.utils.helpers import timer, batcher
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Embeddings, EmbeddingsDates from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EmbeddingsAdapters from database.adapters import EntryAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,14 +94,14 @@ class TextEmbeddings(ABC):
with timer("Preparing dataset for regeneration", logger): with timer("Preparing dataset for regeneration", logger):
if regenerate: if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}") logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type) num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0 num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger): with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"): for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_for_file = hashes_by_file[file] hashes_for_file = hashes_by_file[file]
hashes_to_process = set() hashes_to_process = set()
existing_entries = Embeddings.objects.filter( existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type user=user, hashed_value__in=hashes_for_file, file_type=file_type
) )
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
@ -124,7 +124,7 @@ class TextEmbeddings(ABC):
for entry_hash, embedding in entry_batch: for entry_hash, embedding in entry_batch:
entry = hash_to_current_entries[entry_hash] entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append( batch_embeddings_to_create.append(
Embeddings( DbEntry(
user=user, user=user,
embeddings=embedding, embeddings=embedding,
raw=entry.raw, raw=entry.raw,
@ -136,7 +136,7 @@ class TextEmbeddings(ABC):
corpus_id=entry.corpus_id, corpus_id=entry.corpus_id,
) )
) )
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create) new_embeddings = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_embeddings)} new embeddings") logger.debug(f"Created {len(new_embeddings)} new embeddings")
num_new_embeddings += len(new_embeddings) num_new_embeddings += len(new_embeddings)
@ -146,26 +146,26 @@ class TextEmbeddings(ABC):
dates = self.date_filter.extract_dates(embedding.raw) dates = self.date_filter.extract_dates(embedding.raw)
for date in dates: for date in dates:
dates_to_create.append( dates_to_create.append(
EmbeddingsDates( EntryDates(
date=date, date=date,
embeddings=embedding, embeddings=embedding,
) )
) )
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0: if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries") logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger): with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file: for file in hashes_by_file:
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file) existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file] to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes) num_deleted_embeddings += len(to_delete_entry_hashes)
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes)) EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger): with timer("Identify hashes for deleting entries", logger):
if deletion_filenames is not None: if deletion_filenames is not None:
for file_path in deletion_filenames: for file_path in deletion_filenames:
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path) deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count num_deleted_embeddings += deleted_count
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings

View file

@ -48,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
from fastapi.requests import Request from fastapi.requests import Request
from database import adapters from database import adapters
from database.adapters import EmbeddingsAdapters, ConversationAdapters from database.adapters import EntryAdapters, ConversationAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
@ -129,7 +129,7 @@ if not state.demo:
@requires(["authenticated"]) @requires(["authenticated"])
def get_config_data(request: Request): def get_config_data(request: Request):
user = request.user.object user = request.user.object
EmbeddingsAdapters.get_unique_file_types(user) EntryAdapters.get_unique_file_types(user)
return state.config return state.config
@ -145,7 +145,7 @@ if not state.demo:
configuration_update_metadata = {} configuration_update_metadata = {}
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
if state.config.content_type is not None: if state.config.content_type is not None:
configuration_update_metadata["github"] = "github" in enabled_content configuration_update_metadata["github"] = "github" in enabled_content
@ -241,9 +241,9 @@ if not state.demo:
raise ValueError(f"Invalid content type: {content_type}") raise ValueError(f"Invalid content type: {content_type}")
await content_object.objects.filter(user=user).adelete() await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type) await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"} return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200) @api.post("/delete/config/data/processor/conversation/openai", status_code=200)
@ -372,7 +372,7 @@ def get_config_types(
): ):
user = request.user.object user = request.user.object
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user) enabled_file_types = EntryAdapters.get_unique_file_types(user)
configured_content_types = list(enabled_file_types) configured_content_types = list(enabled_file_types)
@ -706,7 +706,7 @@ async def extract_references_and_questions(
if conversation_type == ConversationCommand.General: if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q return compiled_references, inferred_queries, q
if not await EmbeddingsAdapters.user_has_embeddings(user=user): if not await EntryAdapters.user_has_entries(user=user):
logger.warning( logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
) )

View file

@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
# Internal Packages # Internal Packages
from khoj.utils import constants, state from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
@ -84,7 +84,7 @@ if not state.demo:
@requires(["authenticated"], redirect="login_page") @requires(["authenticated"], redirect="login_page")
def config_page(request: Request): def config_page(request: Request):
user = request.user.object user = request.user.object
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig( default_full_config = FullConfig(
content_type=None, content_type=None,
search_type=None, search_type=None,

View file

@ -6,31 +6,31 @@ from typing import List, Tuple, Type, Union, Dict
# External Packages # External Packages
import torch import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util from sentence_transformers import util
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
# Internal Packages # Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.models import BaseEncoder from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters from database.adapters import EntryAdapters
from database.models import KhojUser, Embeddings from database.models import KhojUser, Entry as DbEntry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
search_type_to_embeddings_type = { search_type_to_embeddings_type = {
SearchType.Org.value: Embeddings.EmbeddingsType.ORG, SearchType.Org.value: DbEntry.EntryType.ORG,
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN, SearchType.Markdown.value: DbEntry.EntryType.MARKDOWN,
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT, SearchType.Plaintext.value: DbEntry.EntryType.PLAINTEXT,
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF, SearchType.Pdf.value: DbEntry.EntryType.PDF,
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB, SearchType.Github.value: DbEntry.EntryType.GITHUB,
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION, SearchType.Notion.value: DbEntry.EntryType.NOTION,
SearchType.All.value: None, SearchType.All.value: None,
} }
@ -121,7 +121,7 @@ async def query(
# Find relevant entries for the query # Find relevant entries for the query
top_k = 10 top_k = 10
with timer("Search Time", logger, state.device): with timer("Search Time", logger, state.device):
hits = EmbeddingsAdapters.search_with_embeddings( hits = EntryAdapters.search_with_embeddings(
user=user, user=user,
embeddings=question_embedding, embeddings=question_embedding,
max_results=top_k, max_results=top_k,

View file

@ -17,7 +17,7 @@ from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from database.models import KhojUser from database.models import KhojUser
from database.adapters import EmbeddingsAdapters from database.adapters import EntryAdapters
# Test # Test
@ -178,7 +178,7 @@ def test_get_configured_types_via_api(client, sample_org_data):
# Act # Act
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert # Assert
assert list(enabled_types) == ["org"] assert list(enabled_types) == ["org"]

View file

@ -1,6 +1,5 @@
# System Packages # System Packages
import logging import logging
import locale
from pathlib import Path from pathlib import Path
import os import os
import asyncio import asyncio
@ -14,7 +13,7 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.fs_syncer import collect_files, get_org_files from khoj.utils.fs_syncer import collect_files, get_org_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -402,10 +401,10 @@ def test_text_search_setup_github(content_config: ContentConfig, default_user: K
) )
# Assert # Assert
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count() embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
assert embeddings > 1 assert embeddings > 1
def verify_embeddings(expected_count, user): def verify_embeddings(expected_count, user):
embeddings = Embeddings.objects.filter(user=user, file_type="org").count() embeddings = Entry.objects.filter(user=user, file_type="org").count()
assert embeddings == expected_count assert embeddings == expected_count