mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Rename DbModels Embeddings, EmbeddingsAdapter to Entry, EntryAdapter
Improves readability as name has closer match to underlying constructs - Entry is any atomic item indexed by Khoj. This can be an org-mode entry, a markdown section, a PDF or Notion page etc. - Embeddings are semantic vectors generated by the search ML model that encodes for meaning contained in an entries text. - An "Entry" contains "Embeddings" vectors but also other metadata about the entry like filename etc.
This commit is contained in:
parent
54a387326c
commit
bcbee05a9e
15 changed files with 115 additions and 87 deletions
|
@ -27,7 +27,7 @@ from database.models import (
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
Embeddings,
|
Entry,
|
||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
ConversationProcessorConfig,
|
ConversationProcessorConfig,
|
||||||
|
@ -286,54 +286,54 @@ class ConversationAdapters:
|
||||||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
|
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsAdapters:
|
class EntryAdapters:
|
||||||
word_filer = WordFilter()
|
word_filer = WordFilter()
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
date_filter = DateFilter()
|
date_filter = DateFilter()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
|
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||||
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
|
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_embedding_by_file(user: KhojUser, file_path: str):
|
def delete_entry_by_file(user: KhojUser, file_path: str):
|
||||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
|
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_all_embeddings(user: KhojUser, file_type: str):
|
def delete_all_entries(user: KhojUser, file_type: str):
|
||||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
|
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||||
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
|
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||||
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
|
def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
|
||||||
return embeddings.filter(
|
return entry.filter(
|
||||||
embeddingsdates__date__gte=start_date,
|
entrydates__date__gte=start_date,
|
||||||
embeddingsdates__date__lte=end_date,
|
entrydates__date__lte=end_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def user_has_embeddings(user: KhojUser):
|
async def user_has_entries(user: KhojUser):
|
||||||
return await Embeddings.objects.filter(user=user).aexists()
|
return await Entry.objects.filter(user=user).aexists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||||
q_filter_terms = Q()
|
q_filter_terms = Q()
|
||||||
|
|
||||||
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
|
explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
|
||||||
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
|
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
|
||||||
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
|
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
|
||||||
|
|
||||||
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||||
return Embeddings.objects.filter(user=user)
|
return Entry.objects.filter(user=user)
|
||||||
|
|
||||||
for term in explicit_word_terms:
|
for term in explicit_word_terms:
|
||||||
if term.startswith("+"):
|
if term.startswith("+"):
|
||||||
|
@ -354,32 +354,32 @@ class EmbeddingsAdapters:
|
||||||
if min_date is not None:
|
if min_date is not None:
|
||||||
# Convert the min_date timestamp to yyyy-mm-dd format
|
# Convert the min_date timestamp to yyyy-mm-dd format
|
||||||
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
|
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
|
||||||
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
|
q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date)
|
||||||
if max_date is not None:
|
if max_date is not None:
|
||||||
# Convert the max_date timestamp to yyyy-mm-dd format
|
# Convert the max_date timestamp to yyyy-mm-dd format
|
||||||
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
||||||
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date)
|
||||||
|
|
||||||
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
|
relevant_entries = Entry.objects.filter(user=user).filter(
|
||||||
q_filter_terms,
|
q_filter_terms,
|
||||||
)
|
)
|
||||||
if file_type_filter:
|
if file_type_filter:
|
||||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||||
return relevant_embeddings
|
return relevant_entries
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_with_embeddings(
|
def search_with_embeddings(
|
||||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||||
):
|
):
|
||||||
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
|
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||||
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
|
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||||
distance=CosineDistance("embeddings", embeddings)
|
distance=CosineDistance("embeddings", embeddings)
|
||||||
)
|
)
|
||||||
if file_type_filter:
|
if file_type_filter:
|
||||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||||
relevant_embeddings = relevant_embeddings.order_by("distance")
|
relevant_entries = relevant_entries.order_by("distance")
|
||||||
return relevant_embeddings[:max_results]
|
return relevant_entries[:max_results]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_unique_file_types(user: KhojUser):
|
def get_unique_file_types(user: KhojUser):
|
||||||
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-10-26 23:52
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0009_khojapiuser"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RenameModel(
|
||||||
|
old_name="Embeddings",
|
||||||
|
new_name="Entry",
|
||||||
|
),
|
||||||
|
migrations.RenameModel(
|
||||||
|
old_name="EmbeddingsDates",
|
||||||
|
new_name="EntryDates",
|
||||||
|
),
|
||||||
|
migrations.RenameField(
|
||||||
|
model_name="entrydates",
|
||||||
|
old_name="embeddings",
|
||||||
|
new_name="entry",
|
||||||
|
),
|
||||||
|
migrations.RenameIndex(
|
||||||
|
model_name="entrydates",
|
||||||
|
new_name="database_en_date_8d823c_idx",
|
||||||
|
old_name="database_em_date_a1ba47_idx",
|
||||||
|
),
|
||||||
|
]
|
|
@ -114,8 +114,8 @@ class Conversation(BaseModel):
|
||||||
conversation_log = models.JSONField(default=dict)
|
conversation_log = models.JSONField(default=dict)
|
||||||
|
|
||||||
|
|
||||||
class Embeddings(BaseModel):
|
class Entry(BaseModel):
|
||||||
class EmbeddingsType(models.TextChoices):
|
class EntryType(models.TextChoices):
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
PDF = "pdf"
|
PDF = "pdf"
|
||||||
PLAINTEXT = "plaintext"
|
PLAINTEXT = "plaintext"
|
||||||
|
@ -130,7 +130,7 @@ class Embeddings(BaseModel):
|
||||||
raw = models.TextField()
|
raw = models.TextField()
|
||||||
compiled = models.TextField()
|
compiled = models.TextField()
|
||||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||||
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
|
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
|
||||||
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||||
|
@ -138,9 +138,9 @@ class Embeddings(BaseModel):
|
||||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsDates(BaseModel):
|
class EntryDates(BaseModel):
|
||||||
date = models.DateField()
|
date = models.DateField()
|
||||||
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
|
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
indexes = [
|
indexes = [
|
||||||
|
|
|
@ -13,8 +13,7 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.rawconfig import Entry
|
from database.models import Entry as DbEntry, GithubConfig, KhojUser
|
||||||
from database.models import Embeddings, GithubConfig, KhojUser
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -103,7 +102,7 @@ class GithubToJsonl(TextEmbeddings):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
|
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
|
@ -10,7 +10,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.constants import empty_escape_sequences
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from database.models import Embeddings, KhojUser
|
from database.models import Entry as DbEntry, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -46,7 +46,7 @@ class MarkdownToJsonl(TextEmbeddings):
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries,
|
current_entries,
|
||||||
Embeddings.EmbeddingsType.MARKDOWN,
|
DbEntry.EntryType.MARKDOWN,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from database.models import Embeddings, KhojUser, NotionConfig
|
from database.models import Entry as DbEntry, KhojUser, NotionConfig
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ class NotionToJsonl(TextEmbeddings):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
|
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
|
@ -9,7 +9,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from database.models import Embeddings, KhojUser
|
from database.models import Entry as DbEntry, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -47,7 +47,7 @@ class OrgToJsonl(TextEmbeddings):
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries,
|
current_entries,
|
||||||
Embeddings.EmbeddingsType.ORG,
|
DbEntry.EntryType.ORG,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
|
|
|
@ -11,7 +11,7 @@ from langchain.document_loaders import PyMuPDFLoader
|
||||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from database.models import Embeddings, KhojUser
|
from database.models import Entry as DbEntry, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -45,7 +45,7 @@ class PdfToJsonl(TextEmbeddings):
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries,
|
current_entries,
|
||||||
Embeddings.EmbeddingsType.PDF,
|
DbEntry.EntryType.PDF,
|
||||||
"compiled",
|
"compiled",
|
||||||
logger,
|
logger,
|
||||||
deletion_file_names,
|
deletion_file_names,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from bs4 import BeautifulSoup
|
||||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from database.models import Embeddings, KhojUser
|
from database.models import Entry as DbEntry, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -55,7 +55,7 @@ class PlaintextToJsonl(TextEmbeddings):
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries,
|
current_entries,
|
||||||
Embeddings.EmbeddingsType.PLAINTEXT,
|
DbEntry.EntryType.PLAINTEXT,
|
||||||
key="compiled",
|
key="compiled",
|
||||||
logger=logger,
|
logger=logger,
|
||||||
deletion_filenames=deletion_file_names,
|
deletion_filenames=deletion_file_names,
|
||||||
|
|
|
@ -12,8 +12,8 @@ from khoj.utils.helpers import timer, batcher
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.processor.embeddings import EmbeddingsModel
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from database.models import KhojUser, Embeddings, EmbeddingsDates
|
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||||
from database.adapters import EmbeddingsAdapters
|
from database.adapters import EntryAdapters
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -94,14 +94,14 @@ class TextEmbeddings(ABC):
|
||||||
with timer("Preparing dataset for regeneration", logger):
|
with timer("Preparing dataset for regeneration", logger):
|
||||||
if regenerate:
|
if regenerate:
|
||||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
||||||
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
|
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
|
||||||
|
|
||||||
num_new_embeddings = 0
|
num_new_embeddings = 0
|
||||||
with timer("Identify hashes for adding new entries", logger):
|
with timer("Identify hashes for adding new entries", logger):
|
||||||
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
||||||
hashes_for_file = hashes_by_file[file]
|
hashes_for_file = hashes_by_file[file]
|
||||||
hashes_to_process = set()
|
hashes_to_process = set()
|
||||||
existing_entries = Embeddings.objects.filter(
|
existing_entries = DbEntry.objects.filter(
|
||||||
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||||
)
|
)
|
||||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||||
|
@ -124,7 +124,7 @@ class TextEmbeddings(ABC):
|
||||||
for entry_hash, embedding in entry_batch:
|
for entry_hash, embedding in entry_batch:
|
||||||
entry = hash_to_current_entries[entry_hash]
|
entry = hash_to_current_entries[entry_hash]
|
||||||
batch_embeddings_to_create.append(
|
batch_embeddings_to_create.append(
|
||||||
Embeddings(
|
DbEntry(
|
||||||
user=user,
|
user=user,
|
||||||
embeddings=embedding,
|
embeddings=embedding,
|
||||||
raw=entry.raw,
|
raw=entry.raw,
|
||||||
|
@ -136,7 +136,7 @@ class TextEmbeddings(ABC):
|
||||||
corpus_id=entry.corpus_id,
|
corpus_id=entry.corpus_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
|
new_embeddings = DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||||
logger.debug(f"Created {len(new_embeddings)} new embeddings")
|
logger.debug(f"Created {len(new_embeddings)} new embeddings")
|
||||||
num_new_embeddings += len(new_embeddings)
|
num_new_embeddings += len(new_embeddings)
|
||||||
|
|
||||||
|
@ -146,26 +146,26 @@ class TextEmbeddings(ABC):
|
||||||
dates = self.date_filter.extract_dates(embedding.raw)
|
dates = self.date_filter.extract_dates(embedding.raw)
|
||||||
for date in dates:
|
for date in dates:
|
||||||
dates_to_create.append(
|
dates_to_create.append(
|
||||||
EmbeddingsDates(
|
EntryDates(
|
||||||
date=date,
|
date=date,
|
||||||
embeddings=embedding,
|
embeddings=embedding,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
|
new_dates = EntryDates.objects.bulk_create(dates_to_create)
|
||||||
if len(new_dates) > 0:
|
if len(new_dates) > 0:
|
||||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
logger.debug(f"Created {len(new_dates)} new date entries")
|
||||||
|
|
||||||
with timer("Identify hashes for removed entries", logger):
|
with timer("Identify hashes for removed entries", logger):
|
||||||
for file in hashes_by_file:
|
for file in hashes_by_file:
|
||||||
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
|
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||||
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||||
num_deleted_embeddings += len(to_delete_entry_hashes)
|
num_deleted_embeddings += len(to_delete_entry_hashes)
|
||||||
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||||
|
|
||||||
with timer("Identify hashes for deleting entries", logger):
|
with timer("Identify hashes for deleting entries", logger):
|
||||||
if deletion_filenames is not None:
|
if deletion_filenames is not None:
|
||||||
for file_path in deletion_filenames:
|
for file_path in deletion_filenames:
|
||||||
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
|
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||||
num_deleted_embeddings += deleted_count
|
num_deleted_embeddings += deleted_count
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
|
@ -48,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
|
|
||||||
from database import adapters
|
from database import adapters
|
||||||
from database.adapters import EmbeddingsAdapters, ConversationAdapters
|
from database.adapters import EntryAdapters, ConversationAdapters
|
||||||
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ if not state.demo:
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
def get_config_data(request: Request):
|
def get_config_data(request: Request):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
EmbeddingsAdapters.get_unique_file_types(user)
|
EntryAdapters.get_unique_file_types(user)
|
||||||
|
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ if not state.demo:
|
||||||
|
|
||||||
configuration_update_metadata = {}
|
configuration_update_metadata = {}
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||||
|
|
||||||
if state.config.content_type is not None:
|
if state.config.content_type is not None:
|
||||||
configuration_update_metadata["github"] = "github" in enabled_content
|
configuration_update_metadata["github"] = "github" in enabled_content
|
||||||
|
@ -241,9 +241,9 @@ if not state.demo:
|
||||||
raise ValueError(f"Invalid content type: {content_type}")
|
raise ValueError(f"Invalid content type: {content_type}")
|
||||||
|
|
||||||
await content_object.objects.filter(user=user).adelete()
|
await content_object.objects.filter(user=user).adelete()
|
||||||
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
|
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
|
||||||
|
|
||||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||||
|
@ -372,7 +372,7 @@ def get_config_types(
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
|
enabled_file_types = EntryAdapters.get_unique_file_types(user)
|
||||||
|
|
||||||
configured_content_types = list(enabled_file_types)
|
configured_content_types = list(enabled_file_types)
|
||||||
|
|
||||||
|
@ -706,7 +706,7 @@ async def extract_references_and_questions(
|
||||||
if conversation_type == ConversationCommand.General:
|
if conversation_type == ConversationCommand.General:
|
||||||
return compiled_references, inferred_queries, q
|
return compiled_references, inferred_queries, q
|
||||||
|
|
||||||
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
|
if not await EntryAdapters.user_has_entries(user=user):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
||||||
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ if not state.demo:
|
||||||
@requires(["authenticated"], redirect="login_page")
|
@requires(["authenticated"], redirect="login_page")
|
||||||
def config_page(request: Request):
|
def config_page(request: Request):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
|
||||||
default_full_config = FullConfig(
|
default_full_config = FullConfig(
|
||||||
content_type=None,
|
content_type=None,
|
||||||
search_type=None,
|
search_type=None,
|
||||||
|
|
|
@ -6,31 +6,31 @@ from typing import List, Tuple, Type, Union, Dict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
from sentence_transformers import util
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
|
from khoj.utils.helpers import get_absolute_path, timer
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.state import SearchType
|
from khoj.utils.state import SearchType
|
||||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from database.adapters import EmbeddingsAdapters
|
from database.adapters import EntryAdapters
|
||||||
from database.models import KhojUser, Embeddings
|
from database.models import KhojUser, Entry as DbEntry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
search_type_to_embeddings_type = {
|
search_type_to_embeddings_type = {
|
||||||
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
|
SearchType.Org.value: DbEntry.EntryType.ORG,
|
||||||
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
|
SearchType.Markdown.value: DbEntry.EntryType.MARKDOWN,
|
||||||
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
|
SearchType.Plaintext.value: DbEntry.EntryType.PLAINTEXT,
|
||||||
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
|
SearchType.Pdf.value: DbEntry.EntryType.PDF,
|
||||||
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
|
SearchType.Github.value: DbEntry.EntryType.GITHUB,
|
||||||
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
|
SearchType.Notion.value: DbEntry.EntryType.NOTION,
|
||||||
SearchType.All.value: None,
|
SearchType.All.value: None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ async def query(
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
top_k = 10
|
top_k = 10
|
||||||
with timer("Search Time", logger, state.device):
|
with timer("Search Time", logger, state.device):
|
||||||
hits = EmbeddingsAdapters.search_with_embeddings(
|
hits = EntryAdapters.search_with_embeddings(
|
||||||
user=user,
|
user=user,
|
||||||
embeddings=question_embedding,
|
embeddings=question_embedding,
|
||||||
max_results=top_k,
|
max_results=top_k,
|
||||||
|
|
|
@ -17,7 +17,7 @@ from khoj.search_type import text_search, image_search
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser
|
||||||
from database.adapters import EmbeddingsAdapters
|
from database.adapters import EntryAdapters
|
||||||
|
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
|
@ -178,7 +178,7 @@ def test_get_configured_types_via_api(client, sample_org_data):
|
||||||
# Act
|
# Act
|
||||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||||
|
|
||||||
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert list(enabled_types) == ["org"]
|
assert list(enabled_types) == ["org"]
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
# System Packages
|
# System Packages
|
||||||
import logging
|
import logging
|
||||||
import locale
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -14,7 +13,7 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||||
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -402,10 +401,10 @@ def test_text_search_setup_github(content_config: ContentConfig, default_user: K
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
|
embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
|
||||||
assert embeddings > 1
|
assert embeddings > 1
|
||||||
|
|
||||||
|
|
||||||
def verify_embeddings(expected_count, user):
|
def verify_embeddings(expected_count, user):
|
||||||
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
|
embeddings = Entry.objects.filter(user=user, file_type="org").count()
|
||||||
assert embeddings == expected_count
|
assert embeddings == expected_count
|
||||||
|
|
Loading…
Reference in a new issue