diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 362398d8..080e73d7 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -27,7 +27,7 @@ from database.models import ( KhojApiUser, NotionConfig, GithubConfig, - Embeddings, + Entry, GithubRepoConfig, Conversation, ConversationProcessorConfig, @@ -286,54 +286,54 @@ class ConversationAdapters: return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst() -class EmbeddingsAdapters: +class EntryAdapters: word_filer = WordFilter() file_filter = FileFilter() date_filter = DateFilter() @staticmethod - def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool: - return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists() + def does_entry_exist(user: KhojUser, hashed_value: str) -> bool: + return Entry.objects.filter(user=user, hashed_value=hashed_value).exists() @staticmethod - def delete_embedding_by_file(user: KhojUser, file_path: str): - deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete() + def delete_entry_by_file(user: KhojUser, file_path: str): + deleted_count, _ = Entry.objects.filter(user=user, file_path=file_path).delete() return deleted_count @staticmethod - def delete_all_embeddings(user: KhojUser, file_type: str): - deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete() + def delete_all_entries(user: KhojUser, file_type: str): + deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete() return deleted_count @staticmethod def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): - return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) + return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) @staticmethod - def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]): - 
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete() + def delete_entry_by_hash(user: KhojUser, hashed_values: List[str]): + Entry.objects.filter(user=user, hashed_value__in=hashed_values).delete() @staticmethod - def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date): - return embeddings.filter( - embeddingsdates__date__gte=start_date, - embeddingsdates__date__lte=end_date, + def get_entries_by_date_filter(entry: BaseManager[Entry], start_date: date, end_date: date): + return entry.filter( + entry_dates__date__gte=start_date, + entry_dates__date__lte=end_date, ) @staticmethod - async def user_has_embeddings(user: KhojUser): - return await Embeddings.objects.filter(user=user).aexists() + async def user_has_entries(user: KhojUser): + return await Entry.objects.filter(user=user).aexists() @staticmethod def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): q_filter_terms = Q() - explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query) - file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query) - date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query) + explicit_word_terms = EntryAdapters.word_filer.get_filter_terms(query) + file_filters = EntryAdapters.file_filter.get_filter_terms(query) + date_filters = EntryAdapters.date_filter.get_query_date_range(query) if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0: - return Embeddings.objects.filter(user=user) + return Entry.objects.filter(user=user) for term in explicit_word_terms: if term.startswith("+"): @@ -354,32 +354,32 @@ class EmbeddingsAdapters: if min_date is not None: # Convert the min_date timestamp to yyyy-mm-dd format formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d") - q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date) + q_filter_terms &= Q(entry_dates__date__gte=formatted_min_date) if max_date 
is not None: # Convert the max_date timestamp to yyyy-mm-dd format formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d") - q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date) + q_filter_terms &= Q(entry_dates__date__lte=formatted_max_date) - relevant_embeddings = Embeddings.objects.filter(user=user).filter( + relevant_entries = Entry.objects.filter(user=user).filter( q_filter_terms, ) if file_type_filter: - relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) - return relevant_embeddings + relevant_entries = relevant_entries.filter(file_type=file_type_filter) + return relevant_entries @staticmethod def search_with_embeddings( user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None ): - relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter) - relevant_embeddings = relevant_embeddings.filter(user=user).annotate( + relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter) + relevant_entries = relevant_entries.filter(user=user).annotate( distance=CosineDistance("embeddings", embeddings) ) if file_type_filter: - relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) - relevant_embeddings = relevant_embeddings.order_by("distance") - return relevant_embeddings[:max_results] + relevant_entries = relevant_entries.filter(file_type=file_type_filter) + relevant_entries = relevant_entries.order_by("distance") + return relevant_entries[:max_results] @staticmethod def get_unique_file_types(user: KhojUser): - return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct() + return Entry.objects.filter(user=user).values_list("file_type", flat=True).distinct() diff --git a/src/database/migrations/0010_rename_embeddings_entry_and_more.py b/src/database/migrations/0010_rename_embeddings_entry_and_more.py new file mode 100644 index 00000000..f86b2caa --- /dev/null +++ 
b/src/database/migrations/0010_rename_embeddings_entry_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.5 on 2023-10-26 23:52 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0009_khojapiuser"), + ] + + operations = [ + migrations.RenameModel( + old_name="Embeddings", + new_name="Entry", + ), + migrations.RenameModel( + old_name="EmbeddingsDates", + new_name="EntryDates", + ), + migrations.RenameField( + model_name="entrydates", + old_name="embeddings", + new_name="entry", + ), + migrations.RenameIndex( + model_name="entrydates", + new_name="database_en_date_8d823c_idx", + old_name="database_em_date_a1ba47_idx", + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 7c9c3822..fe020601 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -114,8 +114,8 @@ class Conversation(BaseModel): conversation_log = models.JSONField(default=dict) -class Embeddings(BaseModel): - class EmbeddingsType(models.TextChoices): +class Entry(BaseModel): + class EntryType(models.TextChoices): IMAGE = "image" PDF = "pdf" PLAINTEXT = "plaintext" @@ -130,7 +130,7 @@ class Embeddings(BaseModel): raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) - file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT) + file_type = models.CharField(max_length=30, choices=EntryType.choices, default=EntryType.PLAINTEXT) file_path = models.CharField(max_length=400, default=None, null=True, blank=True) file_name = models.CharField(max_length=400, default=None, null=True, blank=True) url = models.URLField(max_length=400, default=None, null=True, blank=True) @@ -138,9 +138,9 @@ class Embeddings(BaseModel): corpus_id = models.UUIDField(default=uuid.uuid4, editable=False) -class EmbeddingsDates(BaseModel): +class EntryDates(BaseModel): 
date = models.DateField() - embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates") + entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="entry_dates") class Meta: indexes = [ diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py index 8feb6a31..a548ae1b 100644 --- a/src/khoj/processor/github/github_to_jsonl.py +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -13,8 +13,7 @@ from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings -from khoj.utils.rawconfig import Entry -from database.models import Embeddings, GithubConfig, KhojUser +from database.models import Entry as DbEntry, GithubConfig, KhojUser logger = logging.getLogger(__name__) @@ -103,7 +102,7 @@ class GithubToJsonl(TextEmbeddings): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user + current_entries, DbEntry.EntryType.GITHUB, key="compiled", logger=logger, user=user ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py index 17136b00..921f2213 100644 --- a/src/khoj/processor/markdown/markdown_to_jsonl.py +++ b/src/khoj/processor/markdown/markdown_to_jsonl.py @@ -10,7 +10,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer from khoj.utils.constants import empty_escape_sequences from khoj.utils.rawconfig import Entry -from database.models import Embeddings, KhojUser +from database.models import 
Entry as DbEntry, KhojUser logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ class MarkdownToJsonl(TextEmbeddings): with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, - Embeddings.EmbeddingsType.MARKDOWN, + DbEntry.EntryType.MARKDOWN, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py index 0081350a..15c21b23 100644 --- a/src/khoj/processor/notion/notion_to_jsonl.py +++ b/src/khoj/processor/notion/notion_to_jsonl.py @@ -10,7 +10,7 @@ from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry, NotionContentConfig from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.rawconfig import Entry -from database.models import Embeddings, KhojUser, NotionConfig +from database.models import Entry as DbEntry, KhojUser, NotionConfig from enum import Enum @@ -250,7 +250,7 @@ class NotionToJsonl(TextEmbeddings): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( - current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user + current_entries, DbEntry.EntryType.NOTION, key="compiled", logger=logger, user=user ) return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py index 90fdc029..9bf85660 100644 --- a/src/khoj/processor/org_mode/org_to_jsonl.py +++ b/src/khoj/processor/org_mode/org_to_jsonl.py @@ -9,7 +9,7 @@ from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry from khoj.utils import state -from database.models import Embeddings, KhojUser +from database.models import Entry as DbEntry, KhojUser logger = 
logging.getLogger(__name__) @@ -47,7 +47,7 @@ class OrgToJsonl(TextEmbeddings): with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, - Embeddings.EmbeddingsType.ORG, + DbEntry.EntryType.ORG, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/pdf/pdf_to_jsonl.py b/src/khoj/processor/pdf/pdf_to_jsonl.py index 3a712c68..feed12d7 100644 --- a/src/khoj/processor/pdf/pdf_to_jsonl.py +++ b/src/khoj/processor/pdf/pdf_to_jsonl.py @@ -11,7 +11,7 @@ from langchain.document_loaders import PyMuPDFLoader from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry -from database.models import Embeddings, KhojUser +from database.models import Entry as DbEntry, KhojUser logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ class PdfToJsonl(TextEmbeddings): with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, - Embeddings.EmbeddingsType.PDF, + DbEntry.EntryType.PDF, "compiled", logger, deletion_file_names, diff --git a/src/khoj/processor/plaintext/plaintext_to_jsonl.py b/src/khoj/processor/plaintext/plaintext_to_jsonl.py index 086808b7..a657ff2f 100644 --- a/src/khoj/processor/plaintext/plaintext_to_jsonl.py +++ b/src/khoj/processor/plaintext/plaintext_to_jsonl.py @@ -9,7 +9,7 @@ from bs4 import BeautifulSoup from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry -from database.models import Embeddings, KhojUser +from database.models import Entry as DbEntry, KhojUser logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class PlaintextToJsonl(TextEmbeddings): with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( current_entries, - Embeddings.EmbeddingsType.PLAINTEXT, + 
DbEntry.EntryType.PLAINTEXT, key="compiled", logger=logger, deletion_filenames=deletion_file_names, diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index 831a032f..3aa6a5b1 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -12,8 +12,8 @@ from khoj.utils.helpers import timer, batcher from khoj.utils.rawconfig import Entry from khoj.processor.embeddings import EmbeddingsModel from khoj.search_filter.date_filter import DateFilter -from database.models import KhojUser, Embeddings, EmbeddingsDates -from database.adapters import EmbeddingsAdapters +from database.models import KhojUser, Entry as DbEntry, EntryDates +from database.adapters import EntryAdapters logger = logging.getLogger(__name__) @@ -94,14 +94,14 @@ class TextEmbeddings(ABC): with timer("Preparing dataset for regeneration", logger): if regenerate: logger.debug(f"Deleting all embeddings for file type {file_type}") - num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type) + num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type) num_new_embeddings = 0 with timer("Identify hashes for adding new entries", logger): for file in tqdm(hashes_by_file, desc="Processing file with hashed values"): hashes_for_file = hashes_by_file[file] hashes_to_process = set() - existing_entries = Embeddings.objects.filter( + existing_entries = DbEntry.objects.filter( user=user, hashed_value__in=hashes_for_file, file_type=file_type ) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) @@ -124,7 +124,7 @@ class TextEmbeddings(ABC): for entry_hash, embedding in entry_batch: entry = hash_to_current_entries[entry_hash] batch_embeddings_to_create.append( - Embeddings( + DbEntry( user=user, embeddings=embedding, raw=entry.raw, @@ -136,7 +136,7 @@ class TextEmbeddings(ABC): corpus_id=entry.corpus_id, ) ) - new_embeddings = Embeddings.objects.bulk_create(batch_embeddings_to_create) + 
new_embeddings = DbEntry.objects.bulk_create(batch_embeddings_to_create) logger.debug(f"Created {len(new_embeddings)} new embeddings") num_new_embeddings += len(new_embeddings) @@ -146,26 +146,26 @@ class TextEmbeddings(ABC): dates = self.date_filter.extract_dates(embedding.raw) for date in dates: dates_to_create.append( - EmbeddingsDates( + EntryDates( date=date, embeddings=embedding, ) ) - new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) + new_dates = EntryDates.objects.bulk_create(dates_to_create) if len(new_dates) > 0: logger.debug(f"Created {len(new_dates)} new date entries") with timer("Identify hashes for removed entries", logger): for file in hashes_by_file: - existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file) + existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file) to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file] num_deleted_embeddings += len(to_delete_entry_hashes) - EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes)) + EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes)) with timer("Identify hashes for deleting entries", logger): if deletion_filenames is not None: for file_path in deletion_filenames: - deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path) + deleted_count = EntryAdapters.delete_entry_by_file(user, file_path) num_deleted_embeddings += deleted_count return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 2607120d..8f6af0bf 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -48,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off from fastapi.requests import Request from database import adapters -from database.adapters import EmbeddingsAdapters, ConversationAdapters +from database.adapters import EntryAdapters, 
ConversationAdapters from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser @@ -129,7 +129,7 @@ if not state.demo: @requires(["authenticated"]) def get_config_data(request: Request): user = request.user.object - EmbeddingsAdapters.get_unique_file_types(user) + EntryAdapters.get_unique_file_types(user) return state.config @@ -145,7 +145,7 @@ if not state.demo: configuration_update_metadata = {} - enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) + enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) if state.config.content_type is not None: configuration_update_metadata["github"] = "github" in enabled_content @@ -241,9 +241,9 @@ if not state.demo: raise ValueError(f"Invalid content type: {content_type}") await content_object.objects.filter(user=user).adelete() - await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type) + await sync_to_async(EntryAdapters.delete_all_entries)(user, content_type) - enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) + enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user) return {"status": "ok"} @api.post("/delete/config/data/processor/conversation/openai", status_code=200) @@ -372,7 +372,7 @@ def get_config_types( ): user = request.user.object - enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user) + enabled_file_types = EntryAdapters.get_unique_file_types(user) configured_content_types = list(enabled_file_types) @@ -706,7 +706,7 @@ async def extract_references_and_questions( if conversation_type == ConversationCommand.General: return compiled_references, inferred_queries, q - if not await EmbeddingsAdapters.user_has_embeddings(user=user): + if not await EntryAdapters.user_has_entries(user=user): logger.warning( "No content index loaded, so cannot extract references from knowledge base. 
Please configure your data sources and update the index to chat with your notes." ) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 06f43430..ef0abe18 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -19,7 +19,7 @@ from khoj.utils.rawconfig import ( # Internal Packages from khoj.utils import constants, state -from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters +from database.adapters import EntryAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig @@ -84,7 +84,7 @@ if not state.demo: @requires(["authenticated"], redirect="login_page") def config_page(request: Request): user = request.user.object - enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) + enabled_content = set(EntryAdapters.get_unique_file_types(user).all()) default_full_config = FullConfig( content_type=None, search_type=None, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index dc6593f5..db3b313c 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -6,31 +6,31 @@ from typing import List, Tuple, Type, Union, Dict # External Packages import torch -from sentence_transformers import SentenceTransformer, CrossEncoder, util +from sentence_transformers import util from asgiref.sync import sync_to_async # Internal Packages from khoj.utils import state -from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer +from khoj.utils.helpers import get_absolute_path, timer from khoj.utils.models import BaseEncoder from khoj.utils.state import SearchType from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.jsonl import load_jsonl from khoj.processor.text_to_jsonl import TextEmbeddings -from 
database.adapters import EmbeddingsAdapters -from database.models import KhojUser, Embeddings +from database.adapters import EntryAdapters +from database.models import KhojUser, Entry as DbEntry logger = logging.getLogger(__name__) search_type_to_embeddings_type = { - SearchType.Org.value: Embeddings.EmbeddingsType.ORG, - SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN, - SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT, - SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF, - SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB, - SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION, + SearchType.Org.value: DbEntry.EntryType.ORG, + SearchType.Markdown.value: DbEntry.EntryType.MARKDOWN, + SearchType.Plaintext.value: DbEntry.EntryType.PLAINTEXT, + SearchType.Pdf.value: DbEntry.EntryType.PDF, + SearchType.Github.value: DbEntry.EntryType.GITHUB, + SearchType.Notion.value: DbEntry.EntryType.NOTION, SearchType.All.value: None, } @@ -121,7 +121,7 @@ async def query( # Find relevant entries for the query top_k = 10 with timer("Search Time", logger, state.device): - hits = EmbeddingsAdapters.search_with_embeddings( + hits = EntryAdapters.search_with_embeddings( user=user, embeddings=question_embedding, max_results=top_k, diff --git a/tests/test_client.py b/tests/test_client.py index 6818c2ba..a1013017 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,7 +17,7 @@ from khoj.search_type import text_search, image_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from database.models import KhojUser -from database.adapters import EmbeddingsAdapters +from database.adapters import EntryAdapters # Test @@ -178,7 +178,7 @@ def test_get_configured_types_via_api(client, sample_org_data): # Act text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) - enabled_types = 
EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) + enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) # Assert assert list(enabled_types) == ["org"] diff --git a/tests/test_text_search.py b/tests/test_text_search.py index ae7f0c20..db26ea7b 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,5 @@ # System Packages import logging -import locale from pathlib import Path import os import asyncio @@ -14,7 +13,7 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.utils.fs_syncer import collect_files, get_org_files -from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig +from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig logger = logging.getLogger(__name__) @@ -402,10 +401,10 @@ def test_text_search_setup_github(content_config: ContentConfig, default_user: K ) # Assert - embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count() + embeddings = Entry.objects.filter(user=default_user, file_type="github").count() assert embeddings > 1 def verify_embeddings(expected_count, user): - embeddings = Embeddings.objects.filter(user=user, file_type="org").count() + embeddings = Entry.objects.filter(user=user, file_type="org").count() assert embeddings == expected_count