mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Rename DbModels Embeddings, EmbeddingsAdapter to Entry, EntryAdapter
Improves readability as name has closer match to underlying constructs - Entry is any atomic item indexed by Khoj. This can be an org-mode entry, a markdown section, a PDF or Notion page etc. - Embeddings are semantic vectors generated by the search ML model that encodes for meaning contained in an entries text. - An "Entry" contains "Embeddings" vectors but also other metadata about the entry like filename etc.
This commit is contained in:
parent
54a387326c
commit
bcbee05a9e
15 changed files with 115 additions and 87 deletions
|
@ -27,7 +27,7 @@ from database.models import (
|
|||
KhojApiUser,
|
||||
NotionConfig,
|
||||
GithubConfig,
|
||||
Embeddings,
|
||||
Entry,
|
||||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ConversationProcessorConfig,
|
||||
|
@ -286,54 +286,54 @@ class ConversationAdapters:
|
|||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||
|
||||
|
||||
class EmbeddingsAdapters:
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
date_filter = DateFilter()
|
||||
|
||||
@staticmethod
|
||||
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||
def does_entry_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||
return Entry.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_file(user: KhojUser, file_path: str):
|
||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
|
||||
def delete_entry_by_file(user: KhojUser, file_path: str):
|
||||
deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_embeddings(user: KhojUser, file_type: str):
|
||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
|
||||
def delete_all_entries(user: KhojUser, file_type: str):
|
||||
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||
def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||
Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||
|
||||
@staticmethod
|
||||
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
|
||||
return embeddings.filter(
|
||||
embeddingsdates__date__gte=start_date,
|
||||
embeddingsdates__date__lte=end_date,
|
||||
def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date):
|
||||
return entry.filter(
|
||||
entrydates__date__gte=start_date,
|
||||
entrydates__date__lte=end_date,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def user_has_embeddings(user: KhojUser):
|
||||
return await Embeddings.objects.filter(user=user).aexists()
|
||||
async def user_has_entries(user: KhojUser):
|
||||
return await Entry.objects.filter(user=user).aexists()
|
||||
|
||||
@staticmethod
|
||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||
q_filter_terms = Q()
|
||||
|
||||
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
|
||||
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
|
||||
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
|
||||
explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query)
|
||||
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
|
||||
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
|
||||
|
||||
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||
return Embeddings.objects.filter(user=user)
|
||||
return Entry.objects.filter(user=user)
|
||||
|
||||
for term in explicit_word_terms:
|
||||
if term.startswith("+"):
|
||||
|
@ -354,32 +354,32 @@ class EmbeddingsAdapters:
|
|||
if min_date is not None:
|
||||
# Convert the min_date timestamp to yyyy-mm-dd format
|
||||
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
|
||||
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
|
||||
q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date)
|
||||
if max_date is not None:
|
||||
# Convert the max_date timestamp to yyyy-mm-dd format
|
||||
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
||||
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
||||
q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date)
|
||||
|
||||
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
|
||||
relevant_entries = Entry.objects.filter(user=user).filter(
|
||||
q_filter_terms,
|
||||
)
|
||||
if file_type_filter:
|
||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||
return relevant_embeddings
|
||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||
return relevant_entries
|
||||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||
):
|
||||
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
|
||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||
distance=CosineDistance("embeddings", embeddings)
|
||||
)
|
||||
if file_type_filter:
|
||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||
relevant_embeddings = relevant_embeddings.order_by("distance")
|
||||
return relevant_embeddings[:max_results]
|
||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||
relevant_entries = relevant_entries.order_by("distance")
|
||||
return relevant_entries[:max_results]
|
||||
|
||||
@staticmethod
|
||||
def get_unique_file_types(user: KhojUser):
|
||||
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-26 23:52
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0009_khojapiuser"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameModel(
|
||||
old_name="Embeddings",
|
||||
new_name="Entry",
|
||||
),
|
||||
migrations.RenameModel(
|
||||
old_name="EmbeddingsDates",
|
||||
new_name="EntryDates",
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="entrydates",
|
||||
old_name="embeddings",
|
||||
new_name="entry",
|
||||
),
|
||||
migrations.RenameIndex(
|
||||
model_name="entrydates",
|
||||
new_name="database_en_date_8d823c_idx",
|
||||
old_name="database_em_date_a1ba47_idx",
|
||||
),
|
||||
]
|
|
@ -114,8 +114,8 @@ class Conversation(BaseModel):
|
|||
conversation_log = models.JSONField(default=dict)
|
||||
|
||||
|
||||
class Embeddings(BaseModel):
|
||||
class EmbeddingsType(models.TextChoices):
|
||||
class Entry(BaseModel):
|
||||
class EntryType(models.TextChoices):
|
||||
IMAGE = "image"
|
||||
PDF = "pdf"
|
||||
PLAINTEXT = "plaintext"
|
||||
|
@ -130,7 +130,7 @@ class Embeddings(BaseModel):
|
|||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
|
||||
file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT)
|
||||
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
|
@ -138,9 +138,9 @@ class Embeddings(BaseModel):
|
|||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
|
||||
|
||||
class EmbeddingsDates(BaseModel):
|
||||
class EntryDates(BaseModel):
|
||||
date = models.DateField()
|
||||
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||
|
||||
class Meta:
|
||||
indexes = [
|
||||
|
|
|
@ -13,8 +13,7 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
|||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, GithubConfig, KhojUser
|
||||
from database.models import Entry as DbEntry, GithubConfig, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -103,7 +102,7 @@ class GithubToJsonl(TextEmbeddings):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
|
||||
current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -10,7 +10,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
|
|||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
from database.models import Entry as DbEntry, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -46,7 +46,7 @@ class MarkdownToJsonl(TextEmbeddings):
|
|||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.MARKDOWN,
|
||||
DbEntry.EntryType.MARKDOWN,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -10,7 +10,7 @@ from khoj.utils.helpers import timer
|
|||
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser, NotionConfig
|
||||
from database.models import Entry as DbEntry, KhojUser, NotionConfig
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
@ -250,7 +250,7 @@ class NotionToJsonl(TextEmbeddings):
|
|||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
|
||||
current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -9,7 +9,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings
|
|||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils import state
|
||||
from database.models import Embeddings, KhojUser
|
||||
from database.models import Entry as DbEntry, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -47,7 +47,7 @@ class OrgToJsonl(TextEmbeddings):
|
|||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.ORG,
|
||||
DbEntry.EntryType.ORG,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -11,7 +11,7 @@ from langchain.document_loaders import PyMuPDFLoader
|
|||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
from database.models import Entry as DbEntry, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -45,7 +45,7 @@ class PdfToJsonl(TextEmbeddings):
|
|||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.PDF,
|
||||
DbEntry.EntryType.PDF,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
|
|
|
@ -9,7 +9,7 @@ from bs4 import BeautifulSoup
|
|||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
from database.models import Entry as DbEntry, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -55,7 +55,7 @@ class PlaintextToJsonl(TextEmbeddings):
|
|||
with timer("Identify new or updated entries", logger):
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.PLAINTEXT,
|
||||
DbEntry.EntryType.PLAINTEXT,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
deletion_filenames=deletion_file_names,
|
||||
|
|
|
@ -12,8 +12,8 @@ from khoj.utils.helpers import timer, batcher
|
|||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from database.models import KhojUser, Embeddings, EmbeddingsDates
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||
from database.adapters import EntryAdapters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -94,14 +94,14 @@ class TextEmbeddings(ABC):
|
|||
with timer("Preparing dataset for regeneration", logger):
|
||||
if regenerate:
|
||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
||||
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
|
||||
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
|
||||
|
||||
num_new_embeddings = 0
|
||||
with timer("Identify hashes for adding new entries", logger):
|
||||
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
||||
hashes_for_file = hashes_by_file[file]
|
||||
hashes_to_process = set()
|
||||
existing_entries = Embeddings.objects.filter(
|
||||
existing_entries = DbEntry.objects.filter(
|
||||
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||
)
|
||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||
|
@ -124,7 +124,7 @@ class TextEmbeddings(ABC):
|
|||
for entry_hash, embedding in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
Embeddings(
|
||||
DbEntry(
|
||||
user=user,
|
||||
embeddings=embedding,
|
||||
raw=entry.raw,
|
||||
|
@ -136,7 +136,7 @@ class TextEmbeddings(ABC):
|
|||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create)
|
||||
new_embeddings = DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Created {len(new_embeddings)} new embeddings")
|
||||
num_new_embeddings += len(new_embeddings)
|
||||
|
||||
|
@ -146,26 +146,26 @@ class TextEmbeddings(ABC):
|
|||
dates = self.date_filter.extract_dates(embedding.raw)
|
||||
for date in dates:
|
||||
dates_to_create.append(
|
||||
EmbeddingsDates(
|
||||
EntryDates(
|
||||
date=date,
|
||||
embeddings=embedding,
|
||||
)
|
||||
)
|
||||
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
|
||||
new_dates = EntryDates.objects.bulk_create(dates_to_create)
|
||||
if len(new_dates) > 0:
|
||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
||||
|
||||
with timer("Identify hashes for removed entries", logger):
|
||||
for file in hashes_by_file:
|
||||
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||
num_deleted_embeddings += len(to_delete_entry_hashes)
|
||||
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||
|
||||
with timer("Identify hashes for deleting entries", logger):
|
||||
if deletion_filenames is not None:
|
||||
for file_path in deletion_filenames:
|
||||
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_embeddings += deleted_count
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -48,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
|
|||
from fastapi.requests import Request
|
||||
|
||||
from database import adapters
|
||||
from database.adapters import EmbeddingsAdapters, ConversationAdapters
|
||||
from database.adapters import EntryAdapters, ConversationAdapters
|
||||
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
|
@ -129,7 +129,7 @@ if not state.demo:
|
|||
@requires(["authenticated"])
|
||||
def get_config_data(request: Request):
|
||||
user = request.user.object
|
||||
EmbeddingsAdapters.get_unique_file_types(user)
|
||||
EntryAdapters.get_unique_file_types(user)
|
||||
|
||||
return state.config
|
||||
|
||||
|
@ -145,7 +145,7 @@ if not state.demo:
|
|||
|
||||
configuration_update_metadata = {}
|
||||
|
||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||
|
||||
if state.config.content_type is not None:
|
||||
configuration_update_metadata["github"] = "github" in enabled_content
|
||||
|
@ -241,9 +241,9 @@ if not state.demo:
|
|||
raise ValueError(f"Invalid content type: {content_type}")
|
||||
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
|
||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type)
|
||||
|
||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||
|
@ -372,7 +372,7 @@ def get_config_types(
|
|||
):
|
||||
user = request.user.object
|
||||
|
||||
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
|
||||
enabled_file_types = EntryAdapters.get_unique_file_types(user)
|
||||
|
||||
configured_content_types = list(enabled_file_types)
|
||||
|
||||
|
@ -706,7 +706,7 @@ async def extract_references_and_questions(
|
|||
if conversation_type == ConversationCommand.General:
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
|
||||
if not await EntryAdapters.user_has_entries(user=user):
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
|
|
|
@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
||||
from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
||||
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
|
||||
|
@ -84,7 +84,7 @@ if not state.demo:
|
|||
@requires(["authenticated"], redirect="login_page")
|
||||
def config_page(request: Request):
|
||||
user = request.user.object
|
||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||
enabled_content = set(EntryAdapters.get_unique_file_types(user).all())
|
||||
default_full_config = FullConfig(
|
||||
content_type=None,
|
||||
search_type=None,
|
||||
|
|
|
@ -6,31 +6,31 @@ from typing import List, Tuple, Type, Union, Dict
|
|||
|
||||
# External Packages
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
from sentence_transformers import util
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.helpers import get_absolute_path, timer
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.models import KhojUser, Embeddings
|
||||
from database.adapters import EntryAdapters
|
||||
from database.models import KhojUser, Entry as DbEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
search_type_to_embeddings_type = {
|
||||
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
|
||||
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
|
||||
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
|
||||
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
|
||||
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
|
||||
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
|
||||
SearchType.Org.value: DbEntry.EntryType.ORG,
|
||||
SearchType.Markdown.value: DbEntry.EntryType.MARKDOWN,
|
||||
SearchType.Plaintext.value: DbEntry.EntryType.PLAINTEXT,
|
||||
SearchType.Pdf.value: DbEntry.EntryType.PDF,
|
||||
SearchType.Github.value: DbEntry.EntryType.GITHUB,
|
||||
SearchType.Notion.value: DbEntry.EntryType.NOTION,
|
||||
SearchType.All.value: None,
|
||||
}
|
||||
|
||||
|
@ -121,7 +121,7 @@ async def query(
|
|||
# Find relevant entries for the query
|
||||
top_k = 10
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = EmbeddingsAdapters.search_with_embeddings(
|
||||
hits = EntryAdapters.search_with_embeddings(
|
||||
user=user,
|
||||
embeddings=question_embedding,
|
||||
max_results=top_k,
|
||||
|
|
|
@ -17,7 +17,7 @@ from khoj.search_type import text_search, image_search
|
|||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from database.models import KhojUser
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.adapters import EntryAdapters
|
||||
|
||||
|
||||
# Test
|
||||
|
@ -178,7 +178,7 @@ def test_get_configured_types_via_api(client, sample_org_data):
|
|||
# Act
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||
|
||||
# Assert
|
||||
assert list(enabled_types) == ["org"]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# System Packages
|
||||
import logging
|
||||
import locale
|
||||
from pathlib import Path
|
||||
import os
|
||||
import asyncio
|
||||
|
@ -14,7 +13,7 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
|||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
||||
from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -402,10 +401,10 @@ def test_text_search_setup_github(content_config: ContentConfig, default_user: K
|
|||
)
|
||||
|
||||
# Assert
|
||||
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
|
||||
embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
|
||||
assert embeddings > 1
|
||||
|
||||
|
||||
def verify_embeddings(expected_count, user):
|
||||
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
|
||||
embeddings = Entry.objects.filter(user=user, file_type="org").count()
|
||||
assert embeddings == expected_count
|
||||
|
|
Loading…
Reference in a new issue