diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..a571e8a1 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,48 @@ +name: pre-commit + +on: + pull_request: + paths: + - src/** + - tests/** + - config/** + - pyproject.toml + - .pre-commit-config.yml + - .github/workflows/test.yml + push: + branches: + - master + paths: + - src/khoj/** + - tests/** + - config/** + - pyproject.toml + - .pre-commit-config.yml + - .github/workflows/test.yml + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + + - name: ⏬️ Install Dependencies + run: | + sudo apt update && sudo apt install -y libegl1 + python -m pip install --upgrade pip + + - name: ⬇️ Install Application + run: pip install --upgrade .[dev] + + - name: 🌡️ Validate Application + run: pre-commit run --hook-stage manual --all diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8aa9be8..84fbb1aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,10 +2,8 @@ name: test on: pull_request: - branches: - - 'master' paths: - - src/khoj/** + - src/** - tests/** - config/** - pyproject.toml @@ -13,7 +11,7 @@ on: - .github/workflows/test.yml push: branches: - - 'master' + - master paths: - src/khoj/** - tests/** @@ -26,6 +24,7 @@ jobs: test: name: Run Tests runs-on: ubuntu-latest + container: ubuntu:jammy strategy: fail-fast: false matrix: @@ -33,6 +32,17 @@ jobs: - '3.9' - '3.10' - '3.11' + + services: + postgres: + image: ankane/pgvector + env: + POSTGRES_PASSWORD: postgres + POSTGRES_USER: postgres + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + steps: - uses: actions/checkout@v3 with: @@ -43,17 +53,37 @@ jobs: with: python-version: ${{ matrix.python_version }} - - name: ⏬️ Install Dependencies + - name: Install Git run: | - sudo apt update && sudo apt install -y libegl1 + apt update && apt install -y git + + - name: ⏬️ Install Dependencies + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update && apt install -y libegl1 sqlite3 libsqlite3-dev libsqlite3-0 + + - name: ⬇️ Install Postgres + env: + DEBIAN_FRONTEND: noninteractive + run : | + apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14 + + - name: ⬇️ Install pip + run: | + apt install -y python3-pip + python -m ensurepip --upgrade python -m pip install --upgrade pip - name: ⬇️ Install Application - run: pip install --upgrade .[dev] - - - name: 🌡️ Validate Application - run: pre-commit run --hook-stage manual --all + run: sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && pip install --upgrade .[dev] - name: 🧪 Test Application + env: + POSTGRES_HOST: postgres + POSTGRES_PORT: 5432 + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres run: pytest timeout-minutes: 10 diff --git a/Dockerfile b/Dockerfile index af271537..9882a236 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,13 +8,20 @@ RUN apt update -y && apt -y install python3-pip git WORKDIR /app # Install Application -COPY . . +COPY pyproject.toml . +COPY README.md . RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \ pip install --no-cache-dir . +# Copy Source Code +COPY . . 
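+# Note: copying only pyproject.toml and README.md before installing dependencies (above) lets Docker cache the
+# dependency install layer, so rebuilds after source-only changes skip the pip install step.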
+ +# Set the PYTHONPATH environment variable so that Python can find the Django app. +ENV PYTHONPATH=/app/src:$PYTHONPATH + # Run the Application # There are more arguments required for the application to run, # but these should be passed in through the docker-compose.yml file. ARG PORT EXPOSE ${PORT} -ENTRYPOINT ["khoj"] +ENTRYPOINT ["python3", "src/khoj/main.py"] diff --git a/docker-compose.yml b/docker-compose.yml index 5f1bb1f9..d6048916 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,21 @@ version: "3.9" services: + database: + image: ankane/pgvector + ports: + - "5432:5432" + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + volumes: + - khoj_db:/var/lib/postgresql/data/ server: + # Use the following line to run the latest prebuilt khoj image. Comment it out if you want to build from source instead. image: ghcr.io/khoj-ai/khoj:latest + # Uncomment the following two lines to build from source. This will take a few minutes. Leave them commented to use the official image. + # build: + # context: . ports: # If changing the local port (left hand side), no other changes required. # If changing the remote port (right hand side), @@ -26,8 +40,15 @@ services: - ./tests/data/models/:/root/.khoj/search/ - khoj_config:/root/.khoj/ # Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/ + environment: + - POSTGRES_DB=postgres + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_HOST=database + - POSTGRES_PORT=5432 command: --host="0.0.0.0" --port=42110 -vv volumes: khoj_config: + khoj_db: diff --git a/pyproject.toml b/pyproject.toml index 34f15d4b..8732d47a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,13 +59,15 @@ dependencies = [ "requests >= 2.26.0", "bs4 >= 0.0.1", "anyio == 3.7.1", - "pymupdf >= 1.23.3", + "pymupdf >= 1.23.5", "django == 4.2.5", "authlib == 1.2.1", "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'", "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'", "itsdangerous == 2.1.2", "httpx == 0.25.0", + "pgvector == 0.2.3", + "psycopg2-binary == 2.9.9", ] dynamic = ["version"] @@ -91,6 +93,8 @@ dev = [ "mypy >= 1.0.1", "black >= 23.1.0", "pre-commit >= 3.0.4", + "pytest-django == 4.5.2", + "pytest-asyncio == 0.21.1", ] [tool.hatch.version] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..eec111ec --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +DJANGO_SETTINGS_MODULE = app.settings +pythonpath = . src +testpaths = tests diff --git a/src/app/README.md b/src/app/README.md new file mode 100644 index 00000000..7a93ee8b --- /dev/null +++ b/src/app/README.md @@ -0,0 +1,60 @@ +# Django App + +Khoj uses Django as the backend framework primarily for its powerful ORM and the admin interface. The Django app is located in the `src/app` directory. We have one installed app, under the `/database/` directory. This app is responsible for all the database-related operations and holds all of our models. You can find the extensive Django documentation [here](https://docs.djangoproject.com/en/4.2/) 🌈. + +## Setup (Docker) + +### Prerequisites +1. Ensure you have [Docker](https://docs.docker.com/get-docker/) installed. +2. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
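+
+If you only need the Postgres + pgvector container (for example, when running the Khoj server locally, as described in the local setup below), you can start just the `database` service defined in the `docker-compose.yml` above. This is a minimal sketch, assuming that service name:
+
+```bash
+docker-compose up -d database
+```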
+ +### Run + +Using the `docker-compose.yml` file in the root directory, you can run the Khoj app using the following command: +```bash +docker-compose up +``` + +## Setup (Local) + +### Install dependencies + +```bash +pip install -e '.[dev]' +``` + +### Setup the database + +1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/). +2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience. + +```bash +cd /tmp +git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git +cd pgvector +make +make install # may need sudo +``` +3. Create a database + +### Make migrations + +This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted). + +```bash +python3 src/manage.py makemigrations +``` + +### Run migrations + +This command will run any pending migrations in your application. +```bash +python3 src/manage.py migrate +``` + +### Run the server + +While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend. +```bash +python3 src/khoj/main.py +``` diff --git a/src/app/settings.py b/src/app/settings.py index 74c496a7..cfd7cd3c 100644 --- a/src/app/settings.py +++ b/src/app/settings.py @@ -77,8 +77,12 @@ WSGI_APPLICATION = "app.wsgi.application" DATABASES = { "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": BASE_DIR / "db.sqlite3", + "ENGINE": "django.db.backends.postgresql", + "HOST": os.getenv("POSTGRES_HOST", "localhost"), + "PORT": os.getenv("POSTGRES_PORT", "5432"), + "USER": os.getenv("POSTGRES_USER", "postgres"), + "NAME": os.getenv("POSTGRES_DB", "khoj"), + "PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"), } } diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index a72323ae..a7c1c6f9 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,15 +1,30 @@ -from typing import Type, TypeVar +from typing import Type, TypeVar, List import uuid +from datetime import date from django.db import models from django.contrib.sessions.backends.db import SessionStore +from pgvector.django import CosineDistance +from django.db.models.manager import BaseManager +from django.db.models import Q +from torch import Tensor # Import sync_to_async from Django Channels from asgiref.sync import sync_to_async from fastapi import HTTPException -from database.models import KhojUser, GoogleUser, NotionConfig +from database.models import ( + KhojUser, + GoogleUser, + NotionConfig, + GithubConfig, + Embeddings, + GithubRepoConfig, +) +from khoj.search_filter.word_filter import WordFilter +from khoj.search_filter.file_filter import FileFilter +from khoj.search_filter.date_filter import DateFilter ModelType = TypeVar("ModelType", bound=models.Model) @@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser: async def create_google_user(token: dict) -> KhojUser: user_info = token.get("userinfo") - user = await KhojUser.objects.acreate( - username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4() - ) + user = await KhojUser.objects.acreate(username=user_info.get("email"), 
email=user_info.get("email")) await user.asave() await GoogleUser.objects.acreate( sub=user_info.get("sub"), @@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser: if not user: raise HTTPException(status_code=401, detail="Invalid user") return user + + +def get_all_users() -> BaseManager[KhojUser]: + return KhojUser.objects.all() + + +def get_user_github_config(user: KhojUser): + config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() + if not config: + return None + return config + + +def get_user_notion_config(user: KhojUser): + config = NotionConfig.objects.filter(user=user).first() + if not config: + return None + return config + + +async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config): + deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None + deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None + await object.objects.filter(user=user).adelete() + await object.objects.acreate( + input_files=deduped_files, + input_filter=deduped_filters, + index_heading_entries=updated_config.index_heading_entries, + user=user, + ) + + +async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): + config = await GithubConfig.objects.filter(user=user).afirst() + + if not config: + config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user) + else: + config.pat_token = pat_token + await config.asave() + await config.githubrepoconfig.all().adelete() + + for repo in repos: + await GithubRepoConfig.objects.acreate( + name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config + ) + return config + + +class EmbeddingsAdapters: + word_filer = WordFilter() + file_filter = FileFilter() + date_filter = DateFilter() + + @staticmethod + def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool: + return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists() + + @staticmethod + def delete_embedding_by_file(user: KhojUser, file_path: str): + deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete() + return deleted_count + + @staticmethod + def delete_all_embeddings(user: KhojUser, file_type: str): + deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete() + return deleted_count + + @staticmethod + def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str): + return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True) + + @staticmethod + def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]): + Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete() + + @staticmethod + def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date): + return embeddings.filter( + embeddingsdates__date__gte=start_date, + embeddingsdates__date__lte=end_date, + ) + + @staticmethod + async def user_has_embeddings(user: KhojUser): + return await Embeddings.objects.filter(user=user).aexists() + + @staticmethod + def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): + q_filter_terms = Q() + + explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query) + file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query) + date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query) + + if len(explicit_word_terms) == 0 and 
len(file_filters) == 0 and len(date_filters) == 0: + return Embeddings.objects.filter(user=user) + + for term in explicit_word_terms: + if term.startswith("+"): + q_filter_terms &= Q(raw__icontains=term[1:]) + elif term.startswith("-"): + q_filter_terms &= ~Q(raw__icontains=term[1:]) + + q_file_filter_terms = Q() + + if len(file_filters) > 0: + for term in file_filters: + q_file_filter_terms |= Q(file_path__regex=term) + + q_filter_terms &= q_file_filter_terms + + if len(date_filters) > 0: + min_date, max_date = date_filters + if min_date is not None: + # Convert the min_date timestamp to yyyy-mm-dd format + formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d") + q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date) + if max_date is not None: + # Convert the max_date timestamp to yyyy-mm-dd format + formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d") + q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date) + + relevant_embeddings = Embeddings.objects.filter(user=user).filter( + q_filter_terms, + ) + if file_type_filter: + relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) + return relevant_embeddings + + @staticmethod + def search_with_embeddings( + user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None + ): + relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter) + relevant_embeddings = relevant_embeddings.filter(user=user).annotate( + distance=CosineDistance("embeddings", embeddings) + ) + if file_type_filter: + relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter) + relevant_embeddings = relevant_embeddings.order_by("distance") + return relevant_embeddings[:max_results] + + @staticmethod + def get_unique_file_types(user: KhojUser): + return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct() diff --git a/src/database/migrations/0003_user_khoj_configurations_and_more.py b/src/database/migrations/0003_user_khoj_configurations_and_more.py deleted file mode 100644 index 537ba4c4..00000000 --- a/src/database/migrations/0003_user_khoj_configurations_and_more.py +++ /dev/null @@ -1,79 +0,0 @@ -# Generated by Django 4.2.5 on 2023-09-27 17:52 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion -import uuid - - -class Migration(migrations.Migration): - dependencies = [ - ("database", "0002_googleuser"), - ] - - operations = [ - migrations.CreateModel( - name="Configuration", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ], - ), - migrations.CreateModel( - name="ConversationProcessorConfig", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ("conversation", models.JSONField()), - ("enable_offline_chat", models.BooleanField(default=False)), - ], - ), - migrations.CreateModel( - name="GithubConfig", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ("pat_token", models.CharField(max_length=200)), - ("compressed_jsonl", models.CharField(max_length=300)), - ("embeddings_file", models.CharField(max_length=300)), - ( - "config", - models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"), - ), - ], - ), - migrations.AddField( - model_name="khojuser", - name="uuid", - 
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)), - preserve_default=False, - ), - migrations.CreateModel( - name="NotionConfig", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ("token", models.CharField(max_length=200)), - ("compressed_jsonl", models.CharField(max_length=300)), - ("embeddings_file", models.CharField(max_length=300)), - ( - "config", - models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"), - ), - ], - ), - migrations.CreateModel( - name="GithubRepoConfig", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ("name", models.CharField(max_length=200)), - ("owner", models.CharField(max_length=200)), - ("branch", models.CharField(max_length=200)), - ( - "github_config", - models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"), - ), - ], - ), - migrations.AddField( - model_name="configuration", - name="user", - field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), - ), - ] diff --git a/src/database/migrations/0003_vector_extension.py b/src/database/migrations/0003_vector_extension.py new file mode 100644 index 00000000..9de01df2 --- /dev/null +++ b/src/database/migrations/0003_vector_extension.py @@ -0,0 +1,10 @@ +from django.db import migrations +from pgvector.django import VectorExtension + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0002_googleuser"), + ] + + operations = [VectorExtension()] diff --git a/src/database/migrations/0004_conversationprocessorconfig_githubconfig_and_more.py b/src/database/migrations/0004_conversationprocessorconfig_githubconfig_and_more.py new file mode 100644 index 00000000..294fc620 --- /dev/null +++ b/src/database/migrations/0004_conversationprocessorconfig_githubconfig_and_more.py @@ -0,0 +1,193 @@ +# Generated by Django 4.2.5 on 2023-10-11 22:24 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import pgvector.django +import uuid + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0003_vector_extension"), + ] + + operations = [ + migrations.CreateModel( + name="ConversationProcessorConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("conversation", models.JSONField()), + ("enable_offline_chat", models.BooleanField(default=False)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="GithubConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("pat_token", models.CharField(max_length=200)), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="khojuser", + name="uuid", + field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)), + preserve_default=False, + ), + migrations.CreateModel( + name="NotionConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", 
models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("token", models.CharField(max_length=200)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="LocalPlaintextConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("input_files", models.JSONField(default=list, null=True)), + ("input_filter", models.JSONField(default=list, null=True)), + ("index_heading_entries", models.BooleanField(default=False)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="LocalPdfConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("input_files", models.JSONField(default=list, null=True)), + ("input_filter", models.JSONField(default=list, null=True)), + ("index_heading_entries", models.BooleanField(default=False)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="LocalOrgConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("input_files", models.JSONField(default=list, null=True)), + ("input_filter", models.JSONField(default=list, null=True)), + ("index_heading_entries", models.BooleanField(default=False)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="LocalMarkdownConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("input_files", models.JSONField(default=list, null=True)), + ("input_filter", models.JSONField(default=list, null=True)), + ("index_heading_entries", models.BooleanField(default=False)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="GithubRepoConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(max_length=200)), + ("owner", models.CharField(max_length=200)), + ("branch", models.CharField(max_length=200)), + ( + "github_config", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="githubrepoconfig", + to="database.githubconfig", + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="githubconfig", + name="user", + 
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + migrations.CreateModel( + name="Embeddings", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("embeddings", pgvector.django.VectorField(dimensions=384)), + ("raw", models.TextField()), + ("compiled", models.TextField()), + ("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)), + ( + "file_type", + models.CharField( + choices=[ + ("image", "Image"), + ("pdf", "Pdf"), + ("plaintext", "Plaintext"), + ("markdown", "Markdown"), + ("org", "Org"), + ("notion", "Notion"), + ("github", "Github"), + ("conversation", "Conversation"), + ], + default="plaintext", + max_length=30, + ), + ), + ("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)), + ("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)), + ("url", models.URLField(blank=True, default=None, max_length=400, null=True)), + ("hashed_value", models.CharField(max_length=100)), + ( + "user", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/migrations/0005_embeddings_corpus_id.py b/src/database/migrations/0005_embeddings_corpus_id.py new file mode 100644 index 00000000..47f5aa8c --- /dev/null +++ b/src/database/migrations/0005_embeddings_corpus_id.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.5 on 2023-10-13 02:39 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0004_conversationprocessorconfig_githubconfig_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="embeddings", + name="corpus_id", + field=models.UUIDField(default=uuid.uuid4, editable=False), + ), + ] diff --git a/src/database/migrations/0006_embeddingsdates.py b/src/database/migrations/0006_embeddingsdates.py new file mode 100644 index 00000000..9d988ed8 --- /dev/null +++ b/src/database/migrations/0006_embeddingsdates.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.5 on 2023-10-13 19:28 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0005_embeddings_corpus_id"), + ] + + operations = [ + migrations.CreateModel( + name="EmbeddingsDates", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("date", models.DateField()), + ( + "embeddings", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="embeddings_dates", + to="database.embeddings", + ), + ), + ], + options={ + "indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")], + }, + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 6536671b..9a50d94f 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -2,11 +2,25 @@ import uuid from django.db import models from django.contrib.auth.models import AbstractUser +from pgvector.django import VectorField + + +class 
BaseModel(models.Model): + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + abstract = True class KhojUser(AbstractUser): uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False)) + def save(self, *args, **kwargs): + if not self.uuid: + self.uuid = uuid.uuid4() + super().save(*args, **kwargs) + class GoogleUser(models.Model): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) @@ -23,31 +37,85 @@ class GoogleUser(models.Model): return self.name -class Configuration(models.Model): - user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) - - -class NotionConfig(models.Model): +class NotionConfig(BaseModel): token = models.CharField(max_length=200) - compressed_jsonl = models.CharField(max_length=300) - embeddings_file = models.CharField(max_length=300) - config = models.OneToOneField(Configuration, on_delete=models.CASCADE) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class GithubConfig(models.Model): +class GithubConfig(BaseModel): pat_token = models.CharField(max_length=200) - compressed_jsonl = models.CharField(max_length=300) - embeddings_file = models.CharField(max_length=300) - config = models.OneToOneField(Configuration, on_delete=models.CASCADE) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class GithubRepoConfig(models.Model): +class GithubRepoConfig(BaseModel): name = models.CharField(max_length=200) owner = models.CharField(max_length=200) branch = models.CharField(max_length=200) - github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE) + github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig") -class ConversationProcessorConfig(models.Model): +class LocalOrgConfig(BaseModel): + input_files = models.JSONField(default=list, null=True) + input_filter = models.JSONField(default=list, null=True) + index_heading_entries = models.BooleanField(default=False) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class LocalMarkdownConfig(BaseModel): + input_files = models.JSONField(default=list, null=True) + input_filter = models.JSONField(default=list, null=True) + index_heading_entries = models.BooleanField(default=False) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class LocalPdfConfig(BaseModel): + input_files = models.JSONField(default=list, null=True) + input_filter = models.JSONField(default=list, null=True) + index_heading_entries = models.BooleanField(default=False) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class LocalPlaintextConfig(BaseModel): + input_files = models.JSONField(default=list, null=True) + input_filter = models.JSONField(default=list, null=True) + index_heading_entries = models.BooleanField(default=False) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class ConversationProcessorConfig(BaseModel): conversation = models.JSONField() enable_offline_chat = models.BooleanField(default=False) + + +class Embeddings(BaseModel): + class EmbeddingsType(models.TextChoices): + IMAGE = "image" + PDF = "pdf" + PLAINTEXT = "plaintext" + MARKDOWN = "markdown" + ORG = "org" + NOTION = "notion" + GITHUB = "github" + CONVERSATION = "conversation" + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) + embeddings = VectorField(dimensions=384) + raw = models.TextField() + compiled = models.TextField() + heading = 
models.CharField(max_length=1000, default=None, null=True, blank=True) + file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT) + file_path = models.CharField(max_length=400, default=None, null=True, blank=True) + file_name = models.CharField(max_length=400, default=None, null=True, blank=True) + url = models.URLField(max_length=400, default=None, null=True, blank=True) + hashed_value = models.CharField(max_length=100) + corpus_id = models.UUIDField(default=uuid.uuid4, editable=False) + + +class EmbeddingsDates(BaseModel): + date = models.DateField() + embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates") + + class Meta: + indexes = [ + models.Index(fields=["date"]), + ] diff --git a/src/database/tests.py b/src/database/tests.py new file mode 100644 index 00000000..7ce503c2 --- /dev/null +++ b/src/database/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 5f60c663..f65b1056 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -30,6 +30,8 @@ from khoj.utils.helpers import resolve_absolute_path, merge_dicts from khoj.utils.fs_syncer import collect_files from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig from khoj.routers.indexer import configure_content, load_content, configure_search +from database.models import KhojUser +from database.adapters import get_all_users logger = logging.getLogger(__name__) @@ -48,14 +50,28 @@ class UserAuthenticationBackend(AuthenticationBackend): from database.models import KhojUser self.khojuser_manager = KhojUser.objects + self._initialize_default_user() super().__init__() + def _initialize_default_user(self): + if not self.khojuser_manager.filter(username="default").exists(): + self.khojuser_manager.create_user( + username="default", + email="default@example.com", + password="default", + ) + async def authenticate(self, request): current_user = request.session.get("user") if current_user and current_user.get("email"): user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() if user: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + elif not state.anonymous_mode: + user = await self.khojuser_manager.filter(username="default").afirst() + if user: + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(), UnauthenticatedUser() @@ -78,7 +94,11 @@ def initialize_server(config: Optional[FullConfig]): def configure_server( - config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False + config: FullConfig, + regenerate: bool = False, + search_type: Optional[SearchType] = None, + init=False, + user: KhojUser = None, ): # Update Config state.config = config @@ -95,7 +115,7 @@ def configure_server( state.config_lock.acquire() state.SearchType = configure_search_types(state.config) state.search_models = configure_search(state.search_models, state.config.search_type) - initialize_content(regenerate, search_type, init) + initialize_content(regenerate, search_type, init, user) except Exception as e: logger.error(f"🚨 Failed to configure search models", exc_info=True) raise e @@ -103,7 +123,7 @@ def configure_server( state.config_lock.release() -def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False): +def 
initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None): # Initialize Content from Config if state.search_models: try: @@ -112,7 +132,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non state.content_index = load_content(state.config.content_type, state.content_index, state.search_models) else: logger.info("📬 Updating content index...") - all_files = collect_files(state.config.content_type) + all_files = collect_files(user=user) state.content_index = configure_content( state.content_index, state.config.content_type, @@ -120,6 +140,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non state.search_models, regenerate, search_type, + user=user, ) except Exception as e: logger.error(f"🚨 Failed to index content", exc_info=True) @@ -152,9 +173,14 @@ if not state.demo: def update_search_index(): try: logger.info("📬 Updating content index via Scheduler") - all_files = collect_files(state.config.content_type) + for user in get_all_users(): + all_files = collect_files(user=user) + state.content_index = configure_content( + state.content_index, state.config.content_type, all_files, state.search_models, user=user + ) + all_files = collect_files(user=None) state.content_index = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models + state.content_index, state.config.content_type, all_files, state.search_models, user=None ) logger.info("📪 Content index updated via Scheduler") except Exception as e: @@ -164,13 +190,9 @@ if not state.demo: def configure_search_types(config: FullConfig): # Extract core search types core_search_types = {e.name: e.value for e in SearchType} - # Extract configured plugin search types - plugin_search_types = {} - if config.content_type and config.content_type.plugins: - plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()} # Dynamically generate search type enum by merging core search types with configured plugin search types - return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types)) + return Enum("SearchType", merge_dicts(core_search_types, {})) def configure_processor( diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index d41ca26b..6e3a0223 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -10,12 +10,10 @@ Github

Github - {% if current_config.content_type.github %} - {% if current_model_state.github == False %} - Not Configured - {% else %} - Configured - {% endif %} + {% if current_model_state.github == False %} + Not Configured + {% else %} + Configured {% endif %}

@@ -24,7 +22,7 @@
- {% if current_config.content_type.github %} + {% if current_model_state.github %} Update {% else %} Setup @@ -32,7 +30,7 @@
- {% if current_config.content_type.github %} + {% if current_model_state.github %}
@@ -59,7 +55,7 @@
- {% if current_config.content_type.content %} + {% if current_model_state.content %} Update {% else %} Setup @@ -67,7 +63,7 @@
- {% if current_config.content_type.notion %} + {% if current_model_state.notion %}
- {% if current_config.content_type.markdown %} + {% if current_model_state.markdown %} Update {% else %} Setup @@ -102,7 +98,7 @@
- {% if current_config.content_type.markdown %} + {% if current_model_state.markdown %}
- {% if current_config.content_type.org %} + {% if current_model_state.org %} Update {% else %} Setup @@ -137,7 +133,7 @@
- {% if current_config.content_type.org %} + {% if current_model_state.org %}
- {% if current_config.content_type.pdf %} + {% if current_model_state.pdf %} Update {% else %} Setup @@ -172,7 +168,7 @@
- {% if current_config.content_type.pdf %} + {% if current_model_state.pdf %}
- {% if current_config.content_type.plaintext %} + {% if current_model_state.plaintext %} Update {% else %} Setup @@ -207,7 +203,7 @@
- {% if current_config.content_type.plaintext %} + {% if current_model_state.plaintext %}
- - - - - - - - - -
- - - -
- - - -
@@ -107,8 +89,6 @@ submit.addEventListener("click", function(event) { event.preventDefault(); - const compressed_jsonl = document.getElementById("compressed-jsonl").value; - const embeddings_file = document.getElementById("embeddings-file").value; const pat_token = document.getElementById("pat-token").value; if (pat_token == "") { @@ -154,8 +134,6 @@ body: JSON.stringify({ "pat_token": pat_token, "repos": repos, - "compressed_jsonl": compressed_jsonl, - "embeddings_file": embeddings_file, }) }) .then(response => response.json()) diff --git a/src/khoj/interface/web/content_type_input.html b/src/khoj/interface/web/content_type_input.html index 1f0dfa76..f8751ddc 100644 --- a/src/khoj/interface/web/content_type_input.html +++ b/src/khoj/interface/web/content_type_input.html @@ -43,33 +43,6 @@ - - - - - - - - - - - - - - -
- - - -
- - - -
- - - -
@@ -155,9 +128,8 @@ inputFilter = null; } - var compressed_jsonl = document.getElementById("compressed-jsonl").value; - var embeddings_file = document.getElementById("embeddings-file").value; - var index_heading_entries = document.getElementById("index-heading-entries").value; + // var index_heading_entries = document.getElementById("index-heading-entries").value; + var index_heading_entries = true; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; fetch('/api/config/data/content_type/{{ content_type }}', { @@ -169,8 +141,6 @@ body: JSON.stringify({ "input_files": inputFiles, "input_filter": inputFilter, - "compressed_jsonl": compressed_jsonl, - "embeddings_file": embeddings_file, "index_heading_entries": index_heading_entries }) }) diff --git a/src/khoj/interface/web/content_type_notion_input.html b/src/khoj/interface/web/content_type_notion_input.html index dde5f2df..965c1ef5 100644 --- a/src/khoj/interface/web/content_type_notion_input.html +++ b/src/khoj/interface/web/content_type_notion_input.html @@ -20,24 +20,6 @@ - - - - - - - - - -
- - - -
- - - -
@@ -51,8 +33,6 @@ submit.addEventListener("click", function(event) { event.preventDefault(); - const compressed_jsonl = document.getElementById("compressed-jsonl").value; - const embeddings_file = document.getElementById("embeddings-file").value; const token = document.getElementById("token").value; if (token == "") { @@ -70,8 +50,6 @@ }, body: JSON.stringify({ "token": token, - "compressed_jsonl": compressed_jsonl, - "embeddings_file": embeddings_file, }) }) .then(response => response.json()) diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html index 581ed9b8..ccf1ca71 100644 --- a/src/khoj/interface/web/index.html +++ b/src/khoj/interface/web/index.html @@ -172,7 +172,7 @@ url = createRequestUrl(query, type, results_count || 5, rerank); fetch(url, { headers: { - "X-CSRFToken": csrfToken + "Content-Type": "application/json" } }) .then(response => response.json()) @@ -199,8 +199,8 @@ fetch("/api/config/types") .then(response => response.json()) .then(enabled_types => { - // Show warning if no content types are enabled - if (enabled_types.detail) { + // Show warning if no content types are enabled, or just one ("all") + if (enabled_types[0] === "all" && enabled_types.length === 1) { document.getElementById("results").innerHTML = "
To use Khoj search, set up your content plugins on the Khoj settings page.
"; document.getElementById("query").setAttribute("disabled", "disabled"); document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search"); diff --git a/src/app/main.py b/src/khoj/main.py similarity index 88% rename from src/app/main.py rename to src/khoj/main.py index 16f7cced..a713cc97 100644 --- a/src/app/main.py +++ b/src/khoj/main.py @@ -24,10 +24,15 @@ from rich.logging import RichHandler from django.core.asgi import get_asgi_application from django.core.management import call_command -# Internal Packages -from khoj.configure import configure_routes, initialize_server, configure_middleware -from khoj.utils import state -from khoj.utils.cli import cli +# Initialize Django +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") +django.setup() + +# Initialize Django Database +call_command("migrate", "--noinput") + +# Initialize Django Static Files +call_command("collectstatic", "--noinput") # Initialize Django os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") @@ -54,6 +59,11 @@ app.add_middleware( # Set Locale locale.setlocale(locale.LC_ALL, "") +# Internal Packages. We do this after setting up Django so that Django features are accessible to the app. +from khoj.configure import configure_routes, initialize_server, configure_middleware +from khoj.utils import state +from khoj.utils.cli import cli + # Setup Logger rich_handler = RichHandler(rich_tracebacks=True) rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]")) @@ -95,6 +105,8 @@ def run(): # Mount Django and Static Files app.mount("/django", django_app, name="django") + if not os.path.exists("static"): + os.mkdir("static") app.mount("/static", StaticFiles(directory="static"), name="static") # Configure Middleware @@ -111,6 +123,7 @@ def set_state(args): state.host = args.host state.port = args.port state.demo = args.demo + state.anonymous_mode = args.anonymous_mode state.khoj_version = version("khoj-assistant") diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py new file mode 100644 index 00000000..f0e2df77 --- /dev/null +++ b/src/khoj/processor/embeddings.py @@ -0,0 +1,57 @@ +from typing import List + +import torch +from langchain.embeddings import HuggingFaceEmbeddings +from sentence_transformers import CrossEncoder + +from khoj.utils.rawconfig import SearchResponse + + +class EmbeddingsModel: + def __init__(self): + self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" + encode_kwargs = {"normalize_embeddings": True} + # encode_kwargs = {} + + if torch.cuda.is_available(): + # Use CUDA GPU + device = torch.device("cuda:0") + elif torch.backends.mps.is_available(): + # Use Apple M1 Metal Acceleration + device = torch.device("mps") + else: + device = torch.device("cpu") + + model_kwargs = {"device": device} + self.embeddings_model = HuggingFaceEmbeddings( + model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs + ) + + def embed_query(self, query): + return self.embeddings_model.embed_query(query) + + def embed_documents(self, docs): + return self.embeddings_model.embed_documents(docs) + + +class CrossEncoderModel: + def __init__(self): + self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + if torch.cuda.is_available(): + # Use CUDA GPU + device = torch.device("cuda:0") + + elif torch.backends.mps.is_available(): + # Use Apple M1 Metal Acceleration + device = torch.device("mps") + + else: + device = torch.device("cpu") + + self.cross_encoder_model = 
CrossEncoder(model_name=self.model_name, device=device) + + def predict(self, query, hits: List[SearchResponse]): + cross__inp = [[query, hit.additional["compiled"]] for hit in hits] + cross_scores = self.cross_encoder_model.predict(cross__inp) + return cross_scores diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py index bcd2e530..8feb6a31 100644 --- a/src/khoj/processor/github/github_to_jsonl.py +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -2,7 +2,7 @@ import logging import time from datetime import datetime -from typing import Dict, List, Union +from typing import Dict, List, Union, Tuple # External Packages import requests @@ -12,18 +12,31 @@ from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl -from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.utils.jsonl import compress_jsonl_data +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.rawconfig import Entry +from database.models import Embeddings, GithubConfig, KhojUser logger = logging.getLogger(__name__) -class GithubToJsonl(TextToJsonl): - def __init__(self, config: GithubContentConfig): +class GithubToJsonl(TextEmbeddings): + def __init__(self, config: GithubConfig): super().__init__(config) - self.config = config + raw_repos = config.githubrepoconfig.all() + repos = [] + for repo in raw_repos: + repos.append( + GithubRepoConfig( + name=repo.name, + owner=repo.owner, + branch=repo.branch, + ) + ) + self.config = GithubContentConfig( + pat_token=config.pat_token, + repos=repos, + ) self.session = requests.Session() self.session.headers.update({"Authorization": f"token {self.config.pat_token}"}) @@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl): else: return - def process(self, previous_entries=[], files=None, full_corpus=True): + def process( + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: if self.config.pat_token is None or self.config.pat_token == "": logger.error(f"Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. 
Skipping github content") @@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl): for repo in self.config.repos: current_entries += self.process_repo(repo) - return self.update_entries_with_ids(current_entries, previous_entries) + return self.update_entries_with_ids(current_entries, user=user) def process_repo(self, repo: GithubRepoConfig): repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}" @@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl): current_entries += issue_entries with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger): - current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256) + current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256) return current_entries - def update_entries_with_ids(self, current_entries, previous_entries): + def update_entries_with_ids(self, current_entries, user: KhojUser = None): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user ) - with timer("Write github entries to JSONL file", logger): - # Process Each Entry from All Notes Files - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings def get_files(self, repo_url: str, repo: GithubRepoConfig): # Get the contents of the repository diff --git a/src/khoj/processor/jsonl/__init__.py b/src/khoj/processor/jsonl/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/khoj/processor/jsonl/jsonl_to_jsonl.py b/src/khoj/processor/jsonl/jsonl_to_jsonl.py deleted file mode 100644 index 4a6fab99..00000000 --- a/src/khoj/processor/jsonl/jsonl_to_jsonl.py +++ /dev/null @@ -1,91 +0,0 @@ -# Standard Packages -import glob -import logging -from pathlib import Path -from typing import List - -# Internal Packages -from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.utils.helpers import get_absolute_path, timer -from khoj.utils.jsonl import load_jsonl, compress_jsonl_data -from khoj.utils.rawconfig import Entry - - -logger = logging.getLogger(__name__) - - -class JsonlToJsonl(TextToJsonl): - # Define Functions - def process(self, previous_entries=[], files: dict[str, str] = {}, full_corpus: bool = True): - # Extract required fields from config - input_jsonl_files, input_jsonl_filter, output_file = ( - self.config.input_files, - self.config.input_filter, - self.config.compressed_jsonl, - ) - - # Get Jsonl Input Files to Process - all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter) - - # Extract Entries from specified jsonl files - with timer("Parse entries from jsonl files", logger): - input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files) - current_entries = list(map(Entry.from_dict, input_jsons)) - - # Split entries by max tokens supported by model - with timer("Split entries by max token size supported by model", logger): - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) 
- - # Identify, mark and merge any new entries with previous entries - with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger - ) - - with timer("Write entries to JSONL file", logger): - # Process Each Entry from All Notes Files - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, output_file) - - return entries_with_ids - - @staticmethod - def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None): - "Get all jsonl files to process" - absolute_jsonl_files, filtered_jsonl_files = set(), set() - if jsonl_files: - absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files} - if jsonl_file_filters: - filtered_jsonl_files = { - filtered_file - for jsonl_file_filter in jsonl_file_filters - for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True) - } - - all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files) - - files_with_non_jsonl_extensions = { - jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl") - } - if any(files_with_non_jsonl_extensions): - print(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}") - - logger.debug(f"Processing files: {all_jsonl_files}") - - return all_jsonl_files - - @staticmethod - def extract_jsonl_entries(jsonl_files): - "Extract entries from specified jsonl files" - entries = [] - for jsonl_file in jsonl_files: - entries.extend(load_jsonl(Path(jsonl_file))) - return entries - - @staticmethod - def convert_entries_to_jsonl(entries: List[Entry]): - "Convert each entry to JSON and collate as JSONL" - return "".join([f"{entry.to_json()}\n" for entry in entries]) diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py index c2f0f0bf..17136b00 100644 --- a/src/khoj/processor/markdown/markdown_to_jsonl.py +++ b/src/khoj/processor/markdown/markdown_to_jsonl.py @@ -3,29 +3,28 @@ import logging import re import urllib3 from pathlib import Path -from typing import List +from typing import Tuple, List # Internal Packages -from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer from khoj.utils.constants import empty_escape_sequences -from khoj.utils.jsonl import compress_jsonl_data -from khoj.utils.rawconfig import Entry, TextContentConfig +from khoj.utils.rawconfig import Entry +from database.models import Embeddings, KhojUser logger = logging.getLogger(__name__) -class MarkdownToJsonl(TextToJsonl): - def __init__(self, config: TextContentConfig): - super().__init__(config) - self.config = config +class MarkdownToJsonl(TextEmbeddings): + def __init__(self): + super().__init__() # Define Functions - def process(self, previous_entries=[], files=None, full_corpus: bool = True): + def process( + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: # Extract required fields from config - output_file = self.config.compressed_jsonl - if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl): # Identify, mark and merge 
any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, + Embeddings.EmbeddingsType.MARKDOWN, + "compiled", + logger, + deletion_file_names, + user, + regenerate=regenerate, ) - with timer("Write markdown entries to JSONL file", logger): - # Process Each Entry from All Notes Files - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, output_file) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings @staticmethod def extract_markdown_entries(markdown_files): diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py index 0df56c37..0081350a 100644 --- a/src/khoj/processor/notion/notion_to_jsonl.py +++ b/src/khoj/processor/notion/notion_to_jsonl.py @@ -1,5 +1,6 @@ # Standard Packages import logging +from typing import Tuple # External Packages import requests @@ -7,9 +8,9 @@ import requests # Internal Packages from khoj.utils.helpers import timer from khoj.utils.rawconfig import Entry, NotionContentConfig -from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.utils.jsonl import compress_jsonl_data +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.rawconfig import Entry +from database.models import Embeddings, KhojUser, NotionConfig from enum import Enum @@ -49,10 +50,12 @@ class NotionBlockType(Enum): CALLOUT = "callout" -class NotionToJsonl(TextToJsonl): - def __init__(self, config: NotionContentConfig): +class NotionToJsonl(TextEmbeddings): + def __init__(self, config: NotionConfig): super().__init__(config) - self.config = config + self.config = NotionContentConfig( + token=config.token, + ) self.session = requests.Session() self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"}) self.unsupported_block_types = [ @@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl): self.body_params = {"page_size": 100} - def process(self, previous_entries=[], files=None, full_corpus=True): + def process( + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: current_entries = [] # Get all pages @@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl): page_entries = self.process_page(p_or_d) current_entries.extend(page_entries) - return self.update_entries_with_ids(current_entries, previous_entries) + return self.update_entries_with_ids(current_entries, user) def process_page(self, page): page_id = page["id"] @@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl): title = None return title, content - def update_entries_with_ids(self, current_entries, previous_entries): + def update_entries_with_ids(self, current_entries, user: KhojUser = None): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, 
user=user ) - with timer("Write Notion entries to JSONL file", logger): - # Process Each Entry from all Notion entries - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, self.config.compressed_jsonl) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py index 2f22add4..90fdc029 100644 --- a/src/khoj/processor/org_mode/org_to_jsonl.py +++ b/src/khoj/processor/org_mode/org_to_jsonl.py @@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple # Internal Packages from khoj.processor.org_mode import orgnode -from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer -from khoj.utils.jsonl import compress_jsonl_data -from khoj.utils.rawconfig import Entry, TextContentConfig +from khoj.utils.rawconfig import Entry from khoj.utils import state +from database.models import Embeddings, KhojUser logger = logging.getLogger(__name__) -class OrgToJsonl(TextToJsonl): - def __init__(self, config: TextContentConfig): - super().__init__(config) - self.config = config +class OrgToJsonl(TextEmbeddings): + def __init__(self): + super().__init__() # Define Functions def process( - self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True - ) -> List[Tuple[int, Entry]]: + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: # Extract required fields from config - output_file = self.config.compressed_jsonl - index_heading_entries = self.config.index_heading_entries + index_heading_entries = True if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""]) @@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, + Embeddings.EmbeddingsType.ORG, + "compiled", + logger, + deletion_file_names, + user, + regenerate=regenerate, ) - # Process Each Entry from All Notes Files - with timer("Write org entries to JSONL file", logger): - entries = map(lambda entry: entry[1], entries_with_ids) - jsonl_data = self.convert_org_entries_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, output_file) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings @staticmethod def extract_org_entries(org_files: dict[str, str]): diff --git a/src/khoj/processor/pdf/pdf_to_jsonl.py b/src/khoj/processor/pdf/pdf_to_jsonl.py index c24d9940..3a712c68 100644 --- a/src/khoj/processor/pdf/pdf_to_jsonl.py +++ b/src/khoj/processor/pdf/pdf_to_jsonl.py @@ -1,28 +1,31 @@ # Standard Packages import os import logging -from typing import List +from typing import List, Tuple import base64 # External Packages from langchain.document_loaders import PyMuPDFLoader # Internal Packages -from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer -from 
khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry +from database.models import Embeddings, KhojUser logger = logging.getLogger(__name__) -class PdfToJsonl(TextToJsonl): - # Define Functions - def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True): - # Extract required fields from config - output_file = self.config.compressed_jsonl +class PdfToJsonl(TextEmbeddings): + def __init__(self): + super().__init__() + # Define Functions + def process( + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: + # Extract required fields from config if not full_corpus: deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, + Embeddings.EmbeddingsType.PDF, + "compiled", + logger, + deletion_file_names, + user, + regenerate=regenerate, ) - with timer("Write PDF entries to JSONL file", logger): - # Process Each Entry from All Notes Files - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(jsonl_data, output_file) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings @staticmethod def extract_pdf_entries(pdf_files): @@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl): entry_to_location_map = [] for pdf_file in pdf_files: try: - # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path + # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path tmp_file = f"tmp_pdf_file.pdf" with open(f"{tmp_file}", "wb") as f: bytes = pdf_files[pdf_file] diff --git a/src/khoj/processor/plaintext/plaintext_to_jsonl.py b/src/khoj/processor/plaintext/plaintext_to_jsonl.py index 3acb656e..965a5a7b 100644 --- a/src/khoj/processor/plaintext/plaintext_to_jsonl.py +++ b/src/khoj/processor/plaintext/plaintext_to_jsonl.py @@ -4,22 +4,23 @@ from pathlib import Path from typing import List, Tuple # Internal Packages -from khoj.processor.text_to_jsonl import TextToJsonl +from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.utils.helpers import timer -from khoj.utils.jsonl import compress_jsonl_data -from khoj.utils.rawconfig import Entry +from khoj.utils.rawconfig import Entry, TextContentConfig +from database.models import Embeddings, KhojUser, LocalPlaintextConfig logger = logging.getLogger(__name__) -class PlaintextToJsonl(TextToJsonl): +class PlaintextToJsonl(TextEmbeddings): + def __init__(self): + super().__init__() + # Define Functions def process( - self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True - ) -> List[Tuple[int, Entry]]: - output_file = self.config.compressed_jsonl - + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: if not full_corpus: 
deletion_file_names = set([file for file in files if files[file] == ""]) files_to_process = set(files) - deletion_file_names @@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): - entries_with_ids = TextToJsonl.mark_entries_for_update( - current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names + num_new_embeddings, num_deleted_embeddings = self.update_embeddings( + current_entries, + Embeddings.EmbeddingsType.PLAINTEXT, + key="compiled", + logger=logger, + deletion_filenames=deletion_file_names, + user=user, + regenerate=regenerate, ) - with timer("Write entries to JSONL file", logger): - # Process Each Entry from All Notes Files - entries = list(map(lambda entry: entry[1], entries_with_ids)) - plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries) - - # Compress JSONL formatted Data - compress_jsonl_data(plaintext_data, output_file) - - return entries_with_ids + return num_new_embeddings, num_deleted_embeddings @staticmethod def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]: diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py index 98f5986f..c83c83b1 100644 --- a/src/khoj/processor/text_to_jsonl.py +++ b/src/khoj/processor/text_to_jsonl.py @@ -2,24 +2,33 @@ from abc import ABC, abstractmethod import hashlib import logging -from typing import Callable, List, Tuple, Set +import uuid +from tqdm import tqdm +from typing import Callable, List, Tuple, Set, Any from khoj.utils.helpers import timer + # Internal Packages -from khoj.utils.rawconfig import Entry, TextConfigBase +from khoj.utils.rawconfig import Entry +from khoj.processor.embeddings import EmbeddingsModel +from khoj.search_filter.date_filter import DateFilter +from database.models import KhojUser, Embeddings, EmbeddingsDates +from database.adapters import EmbeddingsAdapters logger = logging.getLogger(__name__) -class TextToJsonl(ABC): - def __init__(self, config: TextConfigBase): +class TextEmbeddings(ABC): + def __init__(self, config: Any = None): + self.embeddings_model = EmbeddingsModel() self.config = config + self.date_filter = DateFilter() @abstractmethod def process( - self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True - ) -> List[Tuple[int, Entry]]: + self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False + ) -> Tuple[int, int]: ... 
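A minimal sketch (not part of the patch) of how a content processor plugs into the new TextEmbeddings contract defined above. The names it relies on — TextEmbeddings, Entry, Embeddings.EmbeddingsType, KhojUser and update_embeddings() — are taken from the hunks in this diff; the TodoToEntries class and its line-per-entry parsing are hypothetical.

import logging
import uuid
from typing import Tuple

from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser

logger = logging.getLogger(__name__)


class TodoToEntries(TextEmbeddings):
    "Hypothetical processor: every non-empty line of a .todo file becomes one entry"

    def process(
        self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
    ) -> Tuple[int, int]:
        current_entries = []
        for path, raw_text in (files or {}).items():
            for line in filter(None, raw_text.splitlines()):
                # corpus_id mirrors the uuid the base class attaches when chunking entries
                current_entries.append(
                    Entry(raw=line, compiled=line, heading=line, file=path, corpus_id=uuid.uuid4())
                )

        # The base class hashes entries, diffs them against this user's existing Embeddings
        # rows and returns (num_new_embeddings, num_deleted_embeddings), as in the hunks above
        return self.update_embeddings(
            current_entries,
            Embeddings.EmbeddingsType.PLAINTEXT,  # a real processor would use its own file type
            "compiled",
            logger,
            user=user,
            regenerate=regenerate,
        )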
@staticmethod @@ -38,6 +47,7 @@ class TextToJsonl(ABC): # Drop long words instead of having entry truncated to maintain quality of entry processed by models compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] + corpus_id = uuid.uuid4() # Split entry into chunks of max tokens for chunk_index in range(0, len(compiled_entry_words), max_tokens): @@ -57,11 +67,103 @@ class TextToJsonl(ABC): raw=entry.raw, heading=entry.heading, file=entry.file, + corpus_id=corpus_id, ) ) return chunked_entries + def update_embeddings( + self, + current_entries: List[Entry], + file_type: str, + key="compiled", + logger: logging.Logger = None, + deletion_filenames: Set[str] = None, + user: KhojUser = None, + regenerate: bool = False, + ): + with timer("Construct current entry hashes", logger): + hashes_by_file = dict[str, set[str]]() + current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries)) + hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) + for entry in tqdm(current_entries, desc="Hashing Entries"): + hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry)) + + num_deleted_embeddings = 0 + with timer("Preparing dataset for regeneration", logger): + if regenerate: + logger.info(f"Deleting all embeddings for file type {file_type}") + num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type) + + num_new_embeddings = 0 + with timer("Identify hashes for adding new entries", logger): + for file in tqdm(hashes_by_file, desc="Processing file with hashed values"): + hashes_for_file = hashes_by_file[file] + hashes_to_process = set() + existing_entries = Embeddings.objects.filter( + user=user, hashed_value__in=hashes_for_file, file_type=file_type + ) + existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) + hashes_to_process = hashes_for_file - existing_entry_hashes + # for hashed_val in hashes_for_file: + # if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val): + # hashes_to_process.add(hashed_val) + + entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] + data_to_embed = [getattr(entry, key) for entry in entries_to_process] + embeddings = self.embeddings_model.embed_documents(data_to_embed) + + with timer("Update the database with new vector embeddings", logger): + embeddings_to_create = [] + for hashed_val, embedding in zip(hashes_to_process, embeddings): + entry = hash_to_current_entries[hashed_val] + embeddings_to_create.append( + Embeddings( + user=user, + embeddings=embedding, + raw=entry.raw, + compiled=entry.compiled, + heading=entry.heading, + file_path=entry.file, + file_type=file_type, + hashed_value=hashed_val, + corpus_id=entry.corpus_id, + ) + ) + new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create) + num_new_embeddings += len(new_embeddings) + + dates_to_create = [] + with timer("Create new date associations for new embeddings", logger): + for embedding in new_embeddings: + dates = self.date_filter.extract_dates(embedding.raw) + for date in dates: + dates_to_create.append( + EmbeddingsDates( + date=date, + embeddings=embedding, + ) + ) + new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create) + if len(new_dates) > 0: + logger.info(f"Created {len(new_dates)} new date entries") + + with timer("Identify hashes for removed entries", logger): + for file in hashes_by_file: + existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file) 
+ to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file] + num_deleted_embeddings += len(to_delete_entry_hashes) + EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes)) + + with timer("Identify hashes for deleting entries", logger): + if deletion_filenames is not None: + for file_path in deletion_filenames: + deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path) + num_deleted_embeddings += deleted_count + + return num_new_embeddings, num_deleted_embeddings + @staticmethod def mark_entries_for_update( current_entries: List[Entry], @@ -72,11 +174,11 @@ class TextToJsonl(ABC): ): # Hash all current and previous entries to identify new entries with timer("Hash previous, current entries", logger): - current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) - previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries)) + current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries)) + previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries)) if deletion_filenames is not None: deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames] - deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries)) + deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries)) else: deletion_entry_hashes = [] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 345429e8..7c3e3392 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -2,14 +2,15 @@ import concurrent.futures import math import time -import yaml import logging import json from typing import List, Optional, Union, Any +import asyncio # External Packages -from fastapi import APIRouter, HTTPException, Header, Request -from sentence_transformers import util +from fastapi import APIRouter, HTTPException, Header, Request, Depends +from starlette.authentication import requires +from asgiref.sync import sync_to_async # Internal Packages from khoj.configure import configure_processor, configure_server @@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter from khoj.utils.config import TextSearchModel from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions from khoj.utils.rawconfig import ( - ContentConfig, FullConfig, ProcessorConfig, SearchConfig, @@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline from fastapi.requests import Request +from database import adapters +from database.adapters import EmbeddingsAdapters +from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser + # Initialize Router api = APIRouter() logger = logging.getLogger(__name__) + +def map_config_to_object(content_type: str): + if content_type == "org": + return LocalOrgConfig + if content_type == "markdown": + return LocalMarkdownConfig + if content_type == "pdf": + return LocalPdfConfig + if content_type == "plaintext": + return LocalPlaintextConfig + + +async def map_config_to_db(config: FullConfig, user: KhojUser): + if config.content_type: + if config.content_type.org: + await LocalOrgConfig.objects.filter(user=user).adelete() + await LocalOrgConfig.objects.acreate( + input_files=config.content_type.org.input_files, + input_filter=config.content_type.org.input_filter, + 
index_heading_entries=config.content_type.org.index_heading_entries, + user=user, + ) + if config.content_type.markdown: + await LocalMarkdownConfig.objects.filter(user=user).adelete() + await LocalMarkdownConfig.objects.acreate( + input_files=config.content_type.markdown.input_files, + input_filter=config.content_type.markdown.input_filter, + index_heading_entries=config.content_type.markdown.index_heading_entries, + user=user, + ) + if config.content_type.pdf: + await LocalPdfConfig.objects.filter(user=user).adelete() + await LocalPdfConfig.objects.acreate( + input_files=config.content_type.pdf.input_files, + input_filter=config.content_type.pdf.input_filter, + index_heading_entries=config.content_type.pdf.index_heading_entries, + user=user, + ) + if config.content_type.plaintext: + await LocalPlaintextConfig.objects.filter(user=user).adelete() + await LocalPlaintextConfig.objects.acreate( + input_files=config.content_type.plaintext.input_files, + input_filter=config.content_type.plaintext.input_filter, + index_heading_entries=config.content_type.plaintext.index_heading_entries, + user=user, + ) + if config.content_type.github: + await adapters.set_user_github_config( + user=user, + pat_token=config.content_type.github.pat_token, + repos=config.content_type.github.repos, + ) + if config.content_type.notion: + await adapters.set_notion_config( + user=user, + token=config.content_type.notion.token, + ) + + # If it's a demo instance, prevent updating any of the configuration. if not state.demo: @@ -64,7 +127,10 @@ if not state.demo: state.processor_config = configure_processor(state.config.processor) @api.get("/config/data", response_model=FullConfig) - def get_config_data(): + def get_config_data(request: Request): + user = request.user.object if request.user.is_authenticated else None + enabled_content = EmbeddingsAdapters.get_unique_file_types(user) + return state.config @api.post("/config/data") @@ -73,20 +139,19 @@ if not state.demo: updated_config: FullConfig, client: Optional[str] = None, ): - state.config = updated_config - with open(state.config_file, "w") as outfile: - yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile) - outfile.close() + user = request.user.object if request.user.is_authenticated else None + await map_config_to_db(updated_config, user) - configuration_update_metadata = dict() + configuration_update_metadata = {} + + enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) if state.config.content_type is not None: - configuration_update_metadata["github"] = state.config.content_type.github is not None - configuration_update_metadata["notion"] = state.config.content_type.notion is not None - configuration_update_metadata["org"] = state.config.content_type.org is not None - configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None - configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None - configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None + configuration_update_metadata["github"] = "github" in enabled_content + configuration_update_metadata["notion"] = "notion" in enabled_content + configuration_update_metadata["org"] = "org" in enabled_content + configuration_update_metadata["pdf"] = "pdf" in enabled_content + configuration_update_metadata["markdown"] = "markdown" in enabled_content if state.config.processor is not None: configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not 
None @@ -101,6 +166,7 @@ if not state.demo: return state.config @api.post("/config/data/content_type/github", status_code=200) + @requires("authenticated") async def set_content_config_github_data( request: Request, updated_config: Union[GithubContentConfig, None], @@ -108,10 +174,13 @@ if not state.demo: ): _initialize_config() - if not state.config.content_type: - state.config.content_type = ContentConfig(**{"github": updated_config}) - else: - state.config.content_type.github = updated_config + user = request.user.object if request.user.is_authenticated else None + + await adapters.set_user_github_config( + user=user, + pat_token=updated_config.pat_token, + repos=updated_config.repos, + ) update_telemetry_state( request=request, @@ -121,11 +190,7 @@ if not state.demo: metadata={"content_type": "github"}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} @api.post("/config/data/content_type/notion", status_code=200) async def set_content_config_notion_data( @@ -135,10 +200,12 @@ if not state.demo: ): _initialize_config() - if not state.config.content_type: - state.config.content_type = ContentConfig(**{"notion": updated_config}) - else: - state.config.content_type.notion = updated_config + user = request.user.object if request.user.is_authenticated else None + + await adapters.set_notion_config( + user=user, + token=updated_config.token, + ) update_telemetry_state( request=request, @@ -148,11 +215,7 @@ if not state.demo: metadata={"content_type": "notion"}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} @api.post("/delete/config/data/content_type/{content_type}", status_code=200) async def remove_content_config_data( @@ -160,8 +223,7 @@ if not state.demo: content_type: str, client: Optional[str] = None, ): - if not state.config or not state.config.content_type: - return {"status": "ok"} + user = request.user.object if request.user.is_authenticated else None update_telemetry_state( request=request, @@ -171,31 +233,13 @@ if not state.demo: metadata={"content_type": content_type}, ) - if state.config.content_type: - state.config.content_type[content_type] = None + content_object = map_config_to_object(content_type) + await content_object.objects.filter(user=user).adelete() + await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type) - if content_type == "github": - state.content_index.github = None - elif content_type == "notion": - state.content_index.notion = None - elif content_type == "plugins": - state.content_index.plugins = None - elif content_type == "pdf": - state.content_index.pdf = None - elif content_type == "markdown": - state.content_index.markdown = None - elif content_type == "org": - state.content_index.org = None - elif content_type == "plaintext": - state.content_index.plaintext = None - else: - logger.warning(f"Request to delete unknown content type: {content_type} via API") + enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} @api.post("/delete/config/data/processor/conversation/openai", status_code=200) async def remove_processor_conversation_config_data( @@ -228,6 +272,7 @@ if not state.demo: return 
{"status": "error", "message": str(e)} @api.post("/config/data/content_type/{content_type}", status_code=200) + # @requires("authenticated") async def set_content_config_data( request: Request, content_type: str, @@ -236,10 +281,10 @@ if not state.demo: ): _initialize_config() - if not state.config.content_type: - state.config.content_type = ContentConfig(**{content_type: updated_config}) - else: - state.config.content_type[content_type] = updated_config + user = request.user.object if request.user.is_authenticated else None + + content_object = map_config_to_object(content_type) + await adapters.set_text_content_config(user, content_object, updated_config) update_telemetry_state( request=request, @@ -249,11 +294,7 @@ if not state.demo: metadata={"content_type": content_type}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} @api.post("/config/data/processor/conversation/openai", status_code=200) async def set_processor_openai_config_data( @@ -337,24 +378,23 @@ def get_default_config_data(): @api.get("/config/types", response_model=List[str]) -def get_config_types(): - """Get configured content types""" - if state.config is None or state.config.content_type is None: - raise HTTPException( - status_code=500, - detail="Content types not configured. Configure at least one content type on server and restart it.", - ) +def get_config_types( + request: Request, +): + user = request.user.object if request.user.is_authenticated else None + + enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user) + + configured_content_types = list(enabled_file_types) + + if state.config and state.config.content_type: + for ctype in state.config.content_type.dict(exclude_none=True): + configured_content_types.append(ctype) - configured_content_types = state.config.content_type.dict(exclude_none=True) return [ search_type.value for search_type in SearchType - if ( - search_type.value in configured_content_types - and getattr(state.content_index, search_type.value) is not None - ) - or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"]) - or search_type == SearchType.All + if (search_type.value in configured_content_types) or search_type == SearchType.All ] @@ -372,6 +412,7 @@ async def search( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): + user = request.user.object if request.user.is_authenticated else None start_time = time.time() # Run validation checks @@ -390,10 +431,11 @@ async def search( search_futures: List[concurrent.futures.Future] = [] # return cached results, if available - query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" - if query_cache_key in state.query_cache: - logger.debug(f"Return response from query cache") - return state.query_cache[query_cache_key] + if user: + query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" + if query_cache_key in state.query_cache[user.uuid]: + logger.debug(f"Return response from query cache") + return state.query_cache[user.uuid][query_cache_key] # Encode query with filter terms removed defiltered_query = user_query @@ -407,84 +449,31 @@ async def search( ] if text_search_models: with timer("Encoding query took", logger=logger): - encoded_asymmetric_query = util.normalize_embeddings( - text_search_models[0].bi_encoder.encode( - [defiltered_query], - convert_to_tensor=True, - device=state.device, - ) - ) + 
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: - if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search: - # query org-mode notes - search_futures += [ - executor.submit( - text_search.query, - user_query, - state.search_models.text_search, - state.content_index.org, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - - if ( - (t == SearchType.Markdown or t == SearchType.All) - and state.content_index.markdown - and state.search_models.text_search - ): + if t in [ + SearchType.All, + SearchType.Org, + SearchType.Markdown, + SearchType.Github, + SearchType.Notion, + SearchType.Plaintext, + ]: # query markdown notes search_futures += [ executor.submit( text_search.query, + user, user_query, - state.search_models.text_search, - state.content_index.markdown, + t, question_embedding=encoded_asymmetric_query, rank_results=r or False, score_threshold=score_threshold, - dedupe=dedupe or True, ) ] - if ( - (t == SearchType.Github or t == SearchType.All) - and state.content_index.github - and state.search_models.text_search - ): - # query github issues - search_futures += [ - executor.submit( - text_search.query, - user_query, - state.search_models.text_search, - state.content_index.github, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - - if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search: - # query pdf files - search_futures += [ - executor.submit( - text_search.query, - user_query, - state.search_models.text_search, - state.content_index.pdf, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - - if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search: + elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search: # query images search_futures += [ executor.submit( @@ -497,70 +486,6 @@ async def search( ) ] - if ( - (t == SearchType.All or t in SearchType) - and state.content_index.plugins - and state.search_models.plugin_search - ): - # query specified plugin type - # Get plugin content, search model for specified search type, or the first one if none specified - plugin_search = state.search_models.plugin_search.get(t.value) or next( - iter(state.search_models.plugin_search.values()) - ) - plugin_content = state.content_index.plugins.get(t.value) or next( - iter(state.content_index.plugins.values()) - ) - search_futures += [ - executor.submit( - text_search.query, - user_query, - plugin_search, - plugin_content, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - - if ( - (t == SearchType.Notion or t == SearchType.All) - and state.content_index.notion - and state.search_models.text_search - ): - # query notion pages - search_futures += [ - executor.submit( - text_search.query, - user_query, - state.search_models.text_search, - state.content_index.notion, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - - if ( - (t == SearchType.Plaintext or t == SearchType.All) 
- and state.content_index.plaintext - and state.search_models.text_search - ): - # query plaintext files - search_futures += [ - executor.submit( - text_search.query, - user_query, - state.search_models.text_search, - state.content_index.plaintext, - question_embedding=encoded_asymmetric_query, - rank_results=r or False, - score_threshold=score_threshold, - dedupe=dedupe or True, - ) - ] - # Query across each requested content types in parallel with timer("Query took", logger): for search_future in concurrent.futures.as_completed(search_futures): @@ -576,15 +501,19 @@ async def search( count=results_count, ) else: - hits, entries = await search_future.result() + hits = await search_future.result() # Collate results - results += text_search.collate_results(hits, entries, results_count) + results += text_search.collate_results(hits, dedupe=dedupe) - # Sort results across all content types and take top results - results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count] + if r: + results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count] + else: + # Sort results across all content types and take top results + results = sorted(results, key=lambda x: float(x.score))[:results_count] # Cache results - state.query_cache[query_cache_key] = results + if user: + state.query_cache[user.uuid][query_cache_key] = results update_telemetry_state( request=request, @@ -596,8 +525,6 @@ async def search( host=host, ) - state.previous_query = user_query - end_time = time.time() logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds") @@ -614,12 +541,13 @@ def update( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): + user = request.user.object if request.user.is_authenticated else None if not state.config: error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}." logger.warning(error_msg) raise HTTPException(status_code=500, detail=error_msg) try: - configure_server(state.config, regenerate=force, search_type=t) + configure_server(state.config, regenerate=force, search_type=t, user=user) except Exception as e: error_msg = f"🚨 Failed to update server via API: {e}" logger.error(error_msg, exc_info=True) @@ -774,6 +702,7 @@ async def extract_references_and_questions( n: int, conversation_type: ConversationCommand = ConversationCommand.Default, ): + user = request.user.object if request.user.is_authenticated else None # Load Conversation History meta_log = state.processor_config.conversation.meta_log @@ -781,7 +710,7 @@ async def extract_references_and_questions( compiled_references: List[Any] = [] inferred_queries: List[str] = [] - if state.content_index is None: + if not EmbeddingsAdapters.user_has_embeddings(user=user): logger.warning( "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." 
) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6b42f29c..be9e8700 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry from khoj.processor.conversation.openai.gpt import converse from khoj.processor.conversation.gpt4all.chat_model import converse_offline from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator +from database.models import KhojUser logger = logging.getLogger(__name__) @@ -40,11 +41,13 @@ def update_telemetry_state( host: Optional[str] = None, metadata: Optional[dict] = None, ): + user: KhojUser = request.user.object if request.user.is_authenticated else None user_state = { "client_host": request.client.host if request.client else None, "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown", + "server_id": str(user.uuid) if user else None, } if metadata: diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index a9656050..c2ef04ff 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -1,6 +1,7 @@ # Standard Packages import logging from typing import Optional, Union, Dict +import asyncio # External Packages from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile @@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state # Internal Packages from khoj.utils import state, constants -from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.processor.notion.notion_to_jsonl import NotionToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl -from khoj.utils.rawconfig import ContentConfig, TextContentConfig from khoj.search_type import text_search, image_search from khoj.utils.yaml import save_config_to_file_updated_state from khoj.utils.config import SearchModels -from khoj.utils.constants import default_config from khoj.utils.helpers import LRU, get_file_type from khoj.utils.rawconfig import ( ContentConfig, FullConfig, SearchConfig, ) -from khoj.search_filter.date_filter import DateFilter -from khoj.search_filter.word_filter import WordFilter -from khoj.search_filter.file_filter import FileFilter from khoj.utils.config import ( ContentIndex, SearchModels, ) +from database.models import ( + KhojUser, + GithubConfig, + NotionConfig, +) logger = logging.getLogger(__name__) @@ -68,14 +68,14 @@ async def update( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): + user = request.user.object if request.user.is_authenticated else None if x_api_key != "secret": raise HTTPException(status_code=401, detail="Invalid API Key") - state.config_lock.acquire() try: logger.info(f"📬 Updating content index via API call by {client} client") org_files: Dict[str, str] = {} markdown_files: Dict[str, str] = {} - pdf_files: Dict[str, str] = {} + pdf_files: Dict[str, bytes] = {} plaintext_files: Dict[str, str] = {} for file in files: @@ -86,7 +86,7 @@ async def update( elif file_type == "markdown": dict_to_update = markdown_files elif file_type == "pdf": - dict_to_update = pdf_files + dict_to_update = pdf_files # type: ignore elif file_type == "plaintext": 
dict_to_update = plaintext_files @@ -120,30 +120,31 @@ async def update( github=None, notion=None, plaintext=None, - plugins=None, ) state.config.content_type = default_content_config save_config_to_file_updated_state() configure_search(state.search_models, state.config.search_type) # Extract required fields from config - state.content_index = configure_content( + loop = asyncio.get_event_loop() + state.content_index = await loop.run_in_executor( + None, + configure_content, state.content_index, state.config.content_type, indexer_input.dict(), state.search_models, - regenerate=force, - t=t, - full_corpus=False, + force, + t, + False, + user, ) - + logger.info(f"Finished processing batch indexing request") except Exception as e: logger.error( f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}", exc_info=True, ) - finally: - state.config_lock.release() update_telemetry_state( request=request, @@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search if search_models is None: search_models = SearchModels() - # Initialize Search Models - if search_config.asymmetric: - logger.info("🔍 📜 Setting up text search model") - search_models.text_search = text_search.initialize_model(search_config.asymmetric) - if search_config.image: logger.info("🔍 🌄 Setting up image search model") search_models.image_search = image_search.initialize_model(search_config.image) @@ -187,16 +183,9 @@ def configure_content( regenerate: bool = False, t: Optional[Union[state.SearchType, str]] = None, full_corpus: bool = True, + user: KhojUser = None, ) -> Optional[ContentIndex]: - def has_valid_text_config(config: TextContentConfig): - return config.input_files or config.input_filter - - # Run Validation Checks - if content_config is None: - logger.warning("🚨 No Content configuration available.") - return None - if content_index is None: - content_index = ContentIndex() + content_index = ContentIndex() if t in [type.value for type in state.SearchType]: t = state.SearchType(t).value @@ -209,59 +198,30 @@ def configure_content( try: # Initialize Org Notes Search - if ( - (t == None or t == state.SearchType.Org.value) - and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"]) - and search_models.text_search - ): - if content_config.org == None: - logger.info("🦄 No configuration for orgmode notes. 
Using default configuration.") - default_configuration = default_config["content-type"]["org"] # type: ignore - content_config.org = TextContentConfig( - compressed_jsonl=default_configuration["compressed-jsonl"], - embeddings_file=default_configuration["embeddings-file"], - ) - + if (t == None or t == state.SearchType.Org.value) and files["org"]: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings - content_index.org = text_search.setup( + text_search.setup( OrgToJsonl, files.get("org"), - content_config.org, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, ) except Exception as e: logger.error(f"🚨 Failed to setup org: {e}", exc_info=True) try: # Initialize Markdown Search - if ( - (t == None or t == state.SearchType.Markdown.value) - and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"]) - and search_models.text_search - and files["markdown"] - ): - if content_config.markdown == None: - logger.info("💎 No configuration for markdown notes. Using default configuration.") - default_configuration = default_config["content-type"]["markdown"] # type: ignore - content_config.markdown = TextContentConfig( - compressed_jsonl=default_configuration["compressed-jsonl"], - embeddings_file=default_configuration["embeddings-file"], - ) - + if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]: logger.info("💎 Setting up search for markdown notes") # Extract Entries, Generate Markdown Embeddings - content_index.markdown = text_search.setup( + text_search.setup( MarkdownToJsonl, files.get("markdown"), - content_config.markdown, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, ) except Exception as e: @@ -269,30 +229,15 @@ def configure_content( try: # Initialize PDF Search - if ( - (t == None or t == state.SearchType.Pdf.value) - and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"]) - and search_models.text_search - and files["pdf"] - ): - if content_config.pdf == None: - logger.info("🖨️ No configuration for pdf notes. Using default configuration.") - default_configuration = default_config["content-type"]["pdf"] # type: ignore - content_config.pdf = TextContentConfig( - compressed_jsonl=default_configuration["compressed-jsonl"], - embeddings_file=default_configuration["embeddings-file"], - ) - + if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]: logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF Embeddings - content_index.pdf = text_search.setup( + text_search.setup( PdfToJsonl, files.get("pdf"), - content_config.pdf, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, ) except Exception as e: @@ -300,30 +245,15 @@ def configure_content( try: # Initialize Plaintext Search - if ( - (t == None or t == state.SearchType.Plaintext.value) - and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"]) - and search_models.text_search - and files["plaintext"] - ): - if content_config.plaintext == None: - logger.info("📄 No configuration for plaintext notes. 
Using default configuration.") - default_configuration = default_config["content-type"]["plaintext"] # type: ignore - content_config.plaintext = TextContentConfig( - compressed_jsonl=default_configuration["compressed-jsonl"], - embeddings_file=default_configuration["embeddings-file"], - ) - + if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]: logger.info("📄 Setting up search for plaintext") # Extract Entries, Generate Plaintext Embeddings - content_index.plaintext = text_search.setup( + text_search.setup( PlaintextToJsonl, files.get("plaintext"), - content_config.plaintext, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, ) except Exception as e: @@ -331,7 +261,12 @@ def configure_content( try: # Initialize Image Search - if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search: + if ( + (t == None or t == state.SearchType.Image.value) + and content_config + and content_config.image + and search_models.image_search + ): logger.info("🌄 Setting up search for images") # Extract Entries, Generate Image Embeddings content_index.image = image_search.setup( @@ -342,17 +277,17 @@ def configure_content( logger.error(f"🚨 Failed to setup images: {e}", exc_info=True) try: - if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search: + github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() + if (t == None or t == state.SearchType.Github.value) and github_config is not None: logger.info("🐙 Setting up search for github") # Extract Entries, Generate Github Embeddings - content_index.github = text_search.setup( + text_search.setup( GithubToJsonl, None, - content_config.github, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, + config=github_config, ) except Exception as e: @@ -360,42 +295,24 @@ def configure_content( try: # Initialize Notion Search - if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search: + notion_config = NotionConfig.objects.filter(user=user).first() + if (t == None or t in state.SearchType.Notion.value) and notion_config: logger.info("🔌 Setting up search for notion") - content_index.notion = text_search.setup( + text_search.setup( NotionToJsonl, None, - content_config.notion, - search_models.text_search.bi_encoder, regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], full_corpus=full_corpus, + user=user, + config=notion_config, ) except Exception as e: logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True) - try: - # Initialize External Plugin Search - if t == None and content_config.plugins and search_models.text_search: - logger.info("🔌 Setting up search for plugins") - content_index.plugins = {} - for plugin_type, plugin_config in content_config.plugins.items(): - content_index.plugins[plugin_type] = text_search.setup( - JsonlToJsonl, - None, - plugin_config, - search_models.text_search.bi_encoder, - regenerate=regenerate, - filters=[DateFilter(), WordFilter(), FileFilter()], - full_corpus=full_corpus, - ) - - except Exception as e: - logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True) - # Invalidate Query Cache - state.query_cache = LRU() + if user: + state.query_cache[user.uuid] = LRU() return content_index 
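For orientation, a rough sketch (again not part of the patch) of the slimmed-down indexing call path that configure_content() now uses: no TextContentConfig, bi-encoder or filter list is threaded through anymore, only the processor class, the raw file contents and the user, with results persisted as per-user Embeddings rows instead of a compressed JSONL plus embeddings file. The reindex_org_notes wrapper is hypothetical; the text_search.setup() keyword arguments mirror the calls above.

from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_type import text_search
from database.models import KhojUser


def reindex_org_notes(org_files: dict[str, str], user: KhojUser, regenerate: bool = False):
    # Mirrors the configure_content() call above: OrgToJsonl parses the files and the
    # resulting entries are upserted into the database for this user; regenerate=True
    # drops and rebuilds the user's org embeddings first.
    text_search.setup(
        OrgToJsonl,
        org_files,
        regenerate=regenerate,
        full_corpus=True,
        user=user,
    )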
@@ -412,44 +329,9 @@ def load_content( if content_index is None: content_index = ContentIndex() - if content_config.org: - logger.info("🦄 Loading orgmode notes") - content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()]) - if content_config.markdown: - logger.info("💎 Loading markdown notes") - content_index.markdown = text_search.load( - content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()] - ) - if content_config.pdf: - logger.info("🖨️ Loading pdf") - content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()]) - if content_config.plaintext: - logger.info("📄 Loading plaintext") - content_index.plaintext = text_search.load( - content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()] - ) if content_config.image: logger.info("🌄 Loading images") content_index.image = image_search.setup( content_config.image, search_models.image_search.image_encoder, regenerate=False ) - if content_config.github: - logger.info("🐙 Loading github") - content_index.github = text_search.load( - content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()] - ) - if content_config.notion: - logger.info("🔌 Loading notion") - content_index.notion = text_search.load( - content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()] - ) - if content_config.plugins: - logger.info("🔌 Loading plugins") - content_index.plugins = {} - for plugin_type, plugin_config in content_config.plugins.items(): - content_index.plugins[plugin_type] = text_search.load( - plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()] - ) - - state.query_cache = LRU() return content_index diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 492a263c..6c79e061 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -3,10 +3,20 @@ from fastapi import APIRouter from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse from fastapi.templating import Jinja2Templates -from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig +from starlette.authentication import requires +from khoj.utils.rawconfig import ( + TextContentConfig, + OpenAIProcessorConfig, + FullConfig, + GithubContentConfig, + GithubRepoConfig, + NotionContentConfig, +) # Internal Packages from khoj.utils import constants, state +from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config +from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig import json @@ -29,10 +39,23 @@ def chat_page(request: Request): return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo}) +def map_config_to_object(content_type: str): + if content_type == "org": + return LocalOrgConfig + if content_type == "markdown": + return LocalMarkdownConfig + if content_type == "pdf": + return LocalPdfConfig + if content_type == "plaintext": + return LocalPlaintextConfig + + if not state.demo: @web_client.get("/config", response_class=HTMLResponse) def config_page(request: Request): + user = request.user.object if request.user.is_authenticated else None + enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) default_full_config = FullConfig( content_type=None, search_type=None, @@ -41,13 +64,13 @@ if not state.demo: current_config = state.config or json.loads(default_full_config.json()) 
successfully_configured = { - "pdf": False, - "markdown": False, - "org": False, + "pdf": ("pdf" in enabled_content), + "markdown": ("markdown" in enabled_content), + "org": ("org" in enabled_content), "image": False, - "github": False, - "notion": False, - "plaintext": False, + "github": ("github" in enabled_content), + "notion": ("notion" in enabled_content), + "plaintext": ("plaintext" in enabled_content), "enable_offline_model": False, "conversation_openai": False, "conversation_gpt4all": False, @@ -56,13 +79,7 @@ if not state.demo: if state.content_index: successfully_configured.update( { - "pdf": state.content_index.pdf is not None, - "markdown": state.content_index.markdown is not None, - "org": state.content_index.org is not None, "image": state.content_index.image is not None, - "github": state.content_index.github is not None, - "notion": state.content_index.notion is not None, - "plaintext": state.content_index.plaintext is not None, } ) @@ -84,22 +101,29 @@ if not state.demo: ) @web_client.get("/config/content_type/github", response_class=HTMLResponse) + @requires(["authenticated"]) def github_config_page(request: Request): - default_copy = constants.default_config.copy() - default_github = default_copy["content-type"]["github"] # type: ignore + user = request.user.object if request.user.is_authenticated else None + current_github_config = get_user_github_config(user) - default_config = TextContentConfig( - compressed_jsonl=default_github["compressed-jsonl"], - embeddings_file=default_github["embeddings-file"], - ) - - current_config = ( - state.config.content_type.github - if state.config and state.config.content_type and state.config.content_type.github - else default_config - ) - - current_config = json.loads(current_config.json()) + if current_github_config: + raw_repos = current_github_config.githubrepoconfig.all() + repos = [] + for repo in raw_repos: + repos.append( + GithubRepoConfig( + name=repo.name, + owner=repo.owner, + branch=repo.branch, + ) + ) + current_config = GithubContentConfig( + pat_token=current_github_config.pat_token, + repos=repos, + ) + current_config = json.loads(current_config.json()) + else: + current_config = {} # type: ignore return templates.TemplateResponse( "content_type_github_input.html", context={"request": request, "current_config": current_config} @@ -107,18 +131,11 @@ if not state.demo: @web_client.get("/config/content_type/notion", response_class=HTMLResponse) def notion_config_page(request: Request): - default_copy = constants.default_config.copy() - default_notion = default_copy["content-type"]["notion"] # type: ignore + user = request.user.object if request.user.is_authenticated else None + current_notion_config = get_user_notion_config(user) - default_config = TextContentConfig( - compressed_jsonl=default_notion["compressed-jsonl"], - embeddings_file=default_notion["embeddings-file"], - ) - - current_config = ( - state.config.content_type.notion - if state.config and state.config.content_type and state.config.content_type.notion - else default_config + current_config = NotionContentConfig( + token=current_notion_config.token if current_notion_config else "", ) current_config = json.loads(current_config.json()) @@ -132,18 +149,16 @@ if not state.demo: if content_type not in VALID_TEXT_CONTENT_TYPES: return templates.TemplateResponse("config.html", context={"request": request}) - default_copy = constants.default_config.copy() - default_content_type = default_copy["content-type"][content_type] # type: ignore + object = 
map_config_to_object(content_type) + user = request.user.object if request.user.is_authenticated else None + config = object.objects.filter(user=user).first() + if config == None: + config = object.objects.create(user=user) - default_config = TextContentConfig( - compressed_jsonl=default_content_type["compressed-jsonl"], - embeddings_file=default_content_type["embeddings-file"], - ) - - current_config = ( - state.config.content_type[content_type] - if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore - else default_config + current_config = TextContentConfig( + input_files=config.input_files, + input_filter=config.input_filter, + index_heading_entries=config.index_heading_entries, ) current_config = json.loads(current_config.json()) diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py index 470f7341..ae596587 100644 --- a/src/khoj/search_filter/base_filter.py +++ b/src/khoj/search_filter/base_filter.py @@ -1,16 +1,9 @@ # Standard Packages from abc import ABC, abstractmethod -from typing import List, Set, Tuple - -# Internal Packages -from khoj.utils.rawconfig import Entry +from typing import List class BaseFilter(ABC): - @abstractmethod - def load(self, entries: List[Entry], *args, **kwargs): - ... - @abstractmethod def get_filter_terms(self, query: str) -> List[str]: ... @@ -18,10 +11,6 @@ class BaseFilter(ABC): def can_filter(self, raw_query: str) -> bool: return len(self.get_filter_terms(raw_query)) > 0 - @abstractmethod - def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: - ... - @abstractmethod def defilter(self, query: str) -> str: ... diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index 39e7bec3..88c70101 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -25,72 +25,42 @@ class DateFilter(BaseFilter): # - dt>="last week" # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']" + raw_date_regex = r"\d{4}-\d{2}-\d{2}" def __init__(self, entry_key="compiled"): self.entry_key = entry_key self.date_to_entry_ids = defaultdict(set) self.cache = LRU() - def load(self, entries, *args, **kwargs): - with timer("Created date filter index", logger): - for id, entry in enumerate(entries): - # Extract dates from entry - for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)): - # Convert date string in entry to unix timestamp - try: - date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp() - except ValueError: - continue - except OSError: - logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}") - continue - self.date_to_entry_ids[date_in_entry].add(id) + def extract_dates(self, content): + pattern_matched_dates = re.findall(self.raw_date_regex, content) + + # Filter down to valid dates + valid_dates = [] + for date_str in pattern_matched_dates: + try: + valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d")) + except ValueError: + continue + + return valid_dates def get_filter_terms(self, query: str) -> List[str]: "Get all filter terms in query" return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)] + def get_query_date_range(self, query) -> List: + with timer("Extract date range to filter from query", logger): + query_daterange = self.extract_date_range(query) + + return query_daterange + def defilter(self, query): # remove date range filter from query query = 
re.sub(rf"\s+{self.date_regex}", " ", query) query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces return query - def apply(self, query, entries): - "Find entries containing any dates that fall within date range specified in query" - # extract date range specified in date filter of query - with timer("Extract date range to filter from query", logger): - query_daterange = self.extract_date_range(query) - - # if no date in query, return all entries - if query_daterange == []: - return query, set(range(len(entries))) - - query = self.defilter(query) - - # return results from cache if exists - cache_key = tuple(query_daterange) - if cache_key in self.cache: - logger.debug(f"Return date filter results from cache") - entries_to_include = self.cache[cache_key] - return query, entries_to_include - - if not self.date_to_entry_ids: - self.load(entries) - - # find entries containing any dates that fall with date range specified in query - with timer("Mark entries satisfying filter", logger): - entries_to_include = set() - for date_in_entry in self.date_to_entry_ids.keys(): - # Check if date in entry is within date range specified in query - if query_daterange[0] <= date_in_entry < query_daterange[1]: - entries_to_include |= self.date_to_entry_ids[date_in_entry] - - # cache results - self.cache[cache_key] = entries_to_include - - return query, entries_to_include - def extract_date_range(self, query): # find date range filter in query date_range_matches = re.findall(self.date_regex, query) @@ -138,6 +108,15 @@ class DateFilter(BaseFilter): if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]: return [] else: + # If the first element is 0, replace it with None + + if effective_date_range[0] == 0: + effective_date_range[0] = None + + # If the second element is inf, replace it with None + if effective_date_range[1] == inf: + effective_date_range[1] = None + return effective_date_range def parse(self, date_str, relative_base=None): diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py index 420bf9e7..291838ea 100644 --- a/src/khoj/search_filter/file_filter.py +++ b/src/khoj/search_filter/file_filter.py @@ -21,62 +21,13 @@ class FileFilter(BaseFilter): self.file_to_entry_map = defaultdict(set) self.cache = LRU() - def load(self, entries, *args, **kwargs): - with timer("Created file filter index", logger): - for id, entry in enumerate(entries): - self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) - def get_filter_terms(self, query: str) -> List[str]: "Get all filter terms in query" - return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)] + return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)] + + def convert_to_regex(self, file_filter: str) -> str: + "Convert file filter to regex" + return file_filter.replace(".", r"\.").replace("*", r".*") def defilter(self, query: str) -> str: return re.sub(self.file_filter_regex, "", query).strip() - - def apply(self, query, entries): - # Extract file filters from raw query - with timer("Extract files_to_search from query", logger): - raw_files_to_search = re.findall(self.file_filter_regex, query) - if not raw_files_to_search: - return query, set(range(len(entries))) - - # Convert simple file filters with no path separator into regex - # e.g. 
"file:notes.org" -> "file:.*notes.org" - files_to_search = [] - for file in sorted(raw_files_to_search): - if "/" not in file and "\\" not in file and "*" not in file: - files_to_search += [f"*{file}"] - else: - files_to_search += [file] - - # Remove filter terms from original query - query = self.defilter(query) - - # Return item from cache if exists - cache_key = tuple(files_to_search) - if cache_key in self.cache: - logger.debug(f"Return file filter results from cache") - included_entry_indices = self.cache[cache_key] - return query, included_entry_indices - - if not self.file_to_entry_map: - self.load(entries, regenerate=False) - - # Mark entries that contain any blocked_words for exclusion - with timer("Mark entries satisfying filter", logger): - included_entry_indices = set.union( - *[ - self.file_to_entry_map[entry_file] - for entry_file in self.file_to_entry_map.keys() - for search_file in files_to_search - if fnmatch.fnmatch(entry_file, search_file) - ], - set(), - ) - if not included_entry_indices: - return query, {} - - # Cache results - self.cache[cache_key] = included_entry_indices - - return query, included_entry_indices diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py index ebf64b34..b2053dbe 100644 --- a/src/khoj/search_filter/word_filter.py +++ b/src/khoj/search_filter/word_filter.py @@ -6,7 +6,7 @@ from typing import List # Internal Packages from khoj.search_filter.base_filter import BaseFilter -from khoj.utils.helpers import LRU, timer +from khoj.utils.helpers import LRU logger = logging.getLogger(__name__) @@ -22,21 +22,6 @@ class WordFilter(BaseFilter): self.word_to_entry_index = defaultdict(set) self.cache = LRU() - def load(self, entries, *args, **kwargs): - with timer("Created word filter index", logger): - self.cache = {} # Clear cache on filter (re-)load - entry_splitter = ( - r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'" - ) - # Create map of words to entries they exist in - for entry_index, entry in enumerate(entries): - for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): - if word == "": - continue - self.word_to_entry_index[word].add(entry_index) - - return self.word_to_entry_index - def get_filter_terms(self, query: str) -> List[str]: "Get all filter terms in query" required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)] @@ -45,47 +30,3 @@ class WordFilter(BaseFilter): def defilter(self, query: str) -> str: return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() - - def apply(self, query, entries): - "Find entries containing required and not blocked words specified in query" - # Separate natural query from required, blocked words filters - with timer("Extract required, blocked filters from query", logger): - required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) - blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) - query = self.defilter(query) - - if len(required_words) == 0 and len(blocked_words) == 0: - return query, set(range(len(entries))) - - # Return item from cache if exists - cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) - if cache_key in self.cache: - logger.debug(f"Return word filter results from cache") - included_entry_indices = self.cache[cache_key] - return query, included_entry_indices - - if not self.word_to_entry_index: - self.load(entries, regenerate=False) 
- - # mark entries that contain all required_words for inclusion - with timer("Mark entries satisfying filter", logger): - entries_with_all_required_words = set(range(len(entries))) - if len(required_words) > 0: - entries_with_all_required_words = set.intersection( - *[self.word_to_entry_index.get(word, set()) for word in required_words] - ) - - # mark entries that contain any blocked_words for exclusion - entries_with_any_blocked_words = set() - if len(blocked_words) > 0: - entries_with_any_blocked_words = set.union( - *[self.word_to_entry_index.get(word, set()) for word in blocked_words] - ) - - # get entries satisfying inclusion and exclusion filters - included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words - - # Cache results - self.cache[cache_key] = included_entry_indices - - return query, included_entry_indices diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 2890baa9..36d6a791 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -2,25 +2,39 @@ import logging import math from pathlib import Path -from typing import List, Tuple, Type, Union +from typing import List, Tuple, Type, Union, Dict # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util -from khoj.processor.text_to_jsonl import TextToJsonl -from khoj.search_filter.base_filter import BaseFilter + +from asgiref.sync import sync_to_async + # Internal Packages from khoj.utils import state -from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer -from khoj.utils.config import TextContent, TextSearchModel +from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer +from khoj.utils.config import TextSearchModel from khoj.utils.models import BaseEncoder -from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry +from khoj.utils.state import SearchType +from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry from khoj.utils.jsonl import load_jsonl - +from khoj.processor.text_to_jsonl import TextEmbeddings +from database.adapters import EmbeddingsAdapters +from database.models import KhojUser, Embeddings logger = logging.getLogger(__name__) +search_type_to_embeddings_type = { + SearchType.Org.value: Embeddings.EmbeddingsType.ORG, + SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN, + SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT, + SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF, + SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB, + SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION, + SearchType.All.value: None, +} + def initialize_model(search_config: TextSearchConfig): "Initialize model for semantic search on text" @@ -117,171 +131,102 @@ def load_embeddings( async def query( + user: KhojUser, raw_query: str, - search_model: TextSearchModel, - content: TextContent, + type: SearchType = SearchType.All, question_embedding: Union[torch.Tensor, None] = None, rank_results: bool = False, score_threshold: float = -math.inf, - dedupe: bool = True, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" - if ( - content.entries is None - or len(content.entries) == 0 - or content.corpus_embeddings is None - or len(content.corpus_embeddings) == 0 - ): - return [], [] - query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings + file_type = 
search_type_to_embeddings_type[type.value] - # Filter query, entries and embeddings before semantic search - query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters) - - # If no entries left after filtering, return empty results - if entries is None or len(entries) == 0: - return [], [] - # If query only had filters it'll be empty now. So short-circuit and return results. - if query.strip() == "": - hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)] - return hits, entries + query = raw_query # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) - question_embedding = util.normalize_embeddings(question_embedding) + question_embedding = state.embeddings_model.embed_query(query) # Find relevant entries for the query - top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus + top_k = 10 with timer("Search Time", logger, state.device): - hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0] + hits = EmbeddingsAdapters.search_with_embeddings( + user=user, + embeddings=question_embedding, + max_results=top_k, + file_type_filter=file_type, + raw_query=raw_query, + ).all() + hits = await sync_to_async(list)(hits) # type: ignore[call-arg] + return hits + + +def collate_results(hits, dedupe=True): + hit_ids = set() + for hit in hits: + if dedupe and hit.corpus_id in hit_ids: + continue + + else: + hit_ids.add(hit.corpus_id) + yield SearchResponse.parse_obj( + { + "entry": hit.raw, + "score": hit.distance, + "additional": { + "file": hit.file_path, + "compiled": hit.compiled, + "heading": hit.heading, + }, + } + ) + + +def rerank_and_sort_results(hits, query): # Score all retrieved entries using the cross-encoder - if rank_results and search_model.cross_encoder: - hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits) + hits = cross_encoder_score(query, hits) - # Filter results by score threshold - hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold] + # Sort results by cross-encoder score followed by bi-encoder score + hits = sort_results(rank_results=True, hits=hits) - # Order results by cross-encoder score followed by bi-encoder score - hits = sort_results(rank_results, hits) - - # Deduplicate entries by raw entry text before showing to users - if dedupe: - hits = deduplicate_results(entries, hits) - - return hits, entries - - -def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]: - return [ - SearchResponse.parse_obj( - { - "entry": entries[hit["corpus_id"]].raw, - "score": f"{hit.get('cross-score') or hit.get('score')}", - "additional": { - "file": entries[hit["corpus_id"]].file, - "compiled": entries[hit["corpus_id"]].compiled, - "heading": entries[hit["corpus_id"]].heading, - }, - } - ) - for hit in hits[0:count] - ] + return hits def setup( - text_to_jsonl: Type[TextToJsonl], + text_to_jsonl: Type[TextEmbeddings], files: dict[str, str], - config: TextConfigBase, - bi_encoder: BaseEncoder, regenerate: bool, - filters: List[BaseFilter] = [], - normalize: bool = True, full_corpus: bool = True, -) -> TextContent: - # Map notes in text files to (compressed) JSONL formatted file - config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - 
previous_entries = [] - if config.compressed_jsonl.exists() and not regenerate: - previous_entries = extract_entries(config.compressed_jsonl) - entries_with_indices = text_to_jsonl(config).process( - previous_entries=previous_entries, files=files, full_corpus=full_corpus - ) - - # Extract Updated Entries - entries = extract_entries(config.compressed_jsonl) - if is_none_or_empty(entries): - config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()]) - raise ValueError( - f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}" + user: KhojUser = None, + config=None, +) -> None: + if config: + num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process( + files=files, full_corpus=full_corpus, user=user, regenerate=regenerate + ) + else: + num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process( + files=files, full_corpus=full_corpus, user=user, regenerate=regenerate ) - # Compute or Load Embeddings - config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = compute_embeddings( - entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize + file_names = [file_name for file_name in files] + + logger.info( + f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}" ) - for filter in filters: - filter.load(entries, regenerate=regenerate) - return TextContent(entries, corpus_embeddings, filters) - - -def load( - config: TextConfigBase, - filters: List[BaseFilter] = [], -) -> TextContent: - # Map notes in text files to (compressed) JSONL formatted file - config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - entries = extract_entries(config.compressed_jsonl) - - # Compute or Load Embeddings - config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = load_embeddings(config.embeddings_file) - - for filter in filters: - filter.load(entries, regenerate=False) - - return TextContent(entries, corpus_embeddings, filters) - - -def apply_filters( - query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter] -) -> Tuple[str, List[Entry], torch.Tensor]: - """Filter query, entries and embeddings before semantic search""" - - with timer("Total Filter Time", logger, state.device): - included_entry_indices = set(range(len(entries))) - filters_in_query = [filter for filter in filters if filter.can_filter(query)] - for filter in filters_in_query: - query, included_entry_indices_by_filter = filter.apply(query, entries) - included_entry_indices.intersection_update(included_entry_indices_by_filter) - - # Get entries (and associated embeddings) satisfying all filters - if not included_entry_indices: - return "", [], torch.tensor([], device=state.device) - else: - entries = [entries[id] for id in included_entry_indices] - corpus_embeddings = torch.index_select( - corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device) - ) - - return query, entries, corpus_embeddings - - -def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]: +def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" with timer("Cross-Encoder Predict Time", logger, state.device): - cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit 
in hits] - cross_scores = cross_encoder.predict(cross_inp) + cross_scores = state.cross_encoder_model.predict(query, hits) # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): - hits[idx]["cross-score"] = cross_scores[idx] + hits[idx]["cross_score"] = cross_scores[idx] return hits @@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: with timer("Rank Time", logger, state.device): hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score if rank_results: - hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score - return hits - - -def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]: - """Deduplicate entries by raw entry text before showing to users - Compiled entries are split by max tokens supported by ML models. - This can result in duplicate hits, entries shown to user.""" - - with timer("Deduplication Time", logger, state.device): - seen, original_hits_count = set(), len(hits) - hits = [ - hit - for hit in hits - if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value] - ] - duplicate_hits = original_hits_count - len(hits) - - logger.debug(f"Removed {duplicate_hits} duplicates") + hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score return hits diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index 1d6106cb..c72320a1 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -2,6 +2,7 @@ import argparse import pathlib from importlib.metadata import version +import os # Internal Packages from khoj.utils.helpers import resolve_absolute_path @@ -34,6 +35,12 @@ def cli(args=None): ) parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit") parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode") + parser.add_argument( + "--anonymous-mode", + action="store_true", + default=False, + help="Run Khoj in anonymous mode. 
This does not require any login for connecting users.", + ) args = parser.parse_args(args) @@ -51,6 +58,8 @@ def cli(args=None): else: args = run_migrations(args) args.config = parse_config_from_file(args.config_file) + if os.environ.get("DEBUG"): + args.config.app.should_log_telemetry = False return args diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 5b3b9f6e..ee5b4f9f 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -41,9 +41,7 @@ class ProcessorType(str, Enum): @dataclass class TextContent: - entries: List[Entry] - corpus_embeddings: torch.Tensor - filters: List[BaseFilter] + enabled: bool @dataclass @@ -67,21 +65,13 @@ class ImageSearchModel: @dataclass class ContentIndex: - org: Optional[TextContent] = None - markdown: Optional[TextContent] = None - pdf: Optional[TextContent] = None - github: Optional[TextContent] = None - notion: Optional[TextContent] = None image: Optional[ImageContent] = None - plaintext: Optional[TextContent] = None - plugins: Optional[Dict[str, TextContent]] = None @dataclass class SearchModels: text_search: Optional[TextSearchModel] = None image_search: Optional[ImageSearchModel] = None - plugin_search: Optional[Dict[str, TextSearchModel]] = None @dataclass diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 9ed97798..181dee04 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/" empty_escape_sequences = "\n|\r|\t| " app_env_filepath = "~/.khoj/env" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" +content_directory = "~/.khoj/content/" empty_config = { "content-type": { diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py index 44fc70ad..fc7e4a2d 100644 --- a/src/khoj/utils/fs_syncer.py +++ b/src/khoj/utils/fs_syncer.py @@ -5,29 +5,39 @@ from typing import Optional from bs4 import BeautifulSoup from khoj.utils.helpers import get_absolute_path, is_none_or_empty -from khoj.utils.rawconfig import TextContentConfig, ContentConfig +from khoj.utils.rawconfig import TextContentConfig from khoj.utils.config import SearchType +from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig logger = logging.getLogger(__name__) -def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All): +def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: files = {} - if config is None: - return files - if search_type == SearchType.All or search_type == SearchType.Org: - files["org"] = get_org_files(config.org) if config.org else {} + org_config = LocalOrgConfig.objects.filter(user=user).first() + files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {} if search_type == SearchType.All or search_type == SearchType.Markdown: - files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {} + markdown_config = LocalMarkdownConfig.objects.filter(user=user).first() + files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {} if search_type == SearchType.All or search_type == SearchType.Plaintext: - files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {} + plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first() + files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {} if search_type 
== SearchType.All or search_type == SearchType.Pdf: - files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {} + pdf_config = LocalPdfConfig.objects.filter(user=user).first() + files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {} return files +def construct_config_from_db(db_config) -> TextContentConfig: + return TextContentConfig( + input_files=db_config.input_files, + input_filter=db_config.input_filter, + index_heading_entries=db_config.index_heading_entries, + ) + + def get_plaintext_files(config: TextContentConfig) -> dict[str, str]: def is_plaintextfile(file: str): "Check if file is plaintext file" diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 9209ff67..e41791f9 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -209,10 +209,12 @@ def log_telemetry( if not app_config or not app_config.should_log_telemetry: return [] + if properties.get("server_id") is None: + properties["server_id"] = get_server_id() + # Populate telemetry data to log request_body = { "telemetry_type": telemetry_type, - "server_id": get_server_id(), "server_version": version("khoj-assistant"), "os": platform.system(), "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index f7c42266..5d2b3ce4 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -1,13 +1,14 @@ # System Packages import json from pathlib import Path -from typing import List, Dict, Optional, Union, Any +from typing import List, Dict, Optional +import uuid # External Packages -from pydantic import BaseModel, validator +from pydantic import BaseModel # Internal Packages -from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty +from khoj.utils.helpers import to_snake_case_from_dash class ConfigBase(BaseModel): @@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase): embeddings_file: Path -class TextContentConfig(TextConfigBase): +class TextContentConfig(ConfigBase): input_files: Optional[List[Path]] input_filter: Optional[List[str]] index_heading_entries: Optional[bool] = False @@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase): branch: Optional[str] = "master" -class GithubContentConfig(TextConfigBase): +class GithubContentConfig(ConfigBase): pat_token: str repos: List[GithubRepoConfig] -class NotionContentConfig(TextConfigBase): +class NotionContentConfig(ConfigBase): token: str @@ -63,7 +64,6 @@ class ContentConfig(ConfigBase): pdf: Optional[TextContentConfig] plaintext: Optional[TextContentConfig] github: Optional[GithubContentConfig] - plugins: Optional[Dict[str, TextContentConfig]] notion: Optional[NotionContentConfig] @@ -122,7 +122,8 @@ class FullConfig(ConfigBase): class SearchResponse(ConfigBase): entry: str - score: str + score: float + cross_score: Optional[float] additional: Optional[dict] @@ -131,14 +132,21 @@ class Entry: compiled: str heading: Optional[str] file: Optional[str] + corpus_id: str def __init__( - self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None + self, + raw: str = None, + compiled: str = None, + heading: Optional[str] = None, + file: Optional[str] = None, + corpus_id: uuid.UUID = None, ): self.raw = raw self.compiled = compiled self.heading = heading self.file = file + self.corpus_id = str(corpus_id) def to_json(self) -> str: return json.dumps(self.__dict__, ensure_ascii=False) @@ -153,4 +161,5 @@ class Entry: compiled=dictionary["compiled"], 
file=dictionary.get("file", None), heading=dictionary.get("heading", None), + corpus_id=dictionary.get("corpus_id", None), ) diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 5ac8a838..d6169d2a 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -2,6 +2,7 @@ import threading from typing import List, Dict from packaging import version +from collections import defaultdict # External Packages import torch @@ -12,10 +13,13 @@ from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel from khoj.utils.helpers import LRU from khoj.utils.rawconfig import FullConfig +from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel # Application Global State config = FullConfig() search_models = SearchModels() +embeddings_model = EmbeddingsModel() +cross_encoder_model = CrossEncoderModel() content_index = ContentIndex() processor_config = ProcessorConfigModel() config_file: Path = None @@ -23,14 +27,14 @@ verbose: int = 0 host: str = None port: int = None cli_args: List[str] = None -query_cache = LRU() +query_cache: Dict[str, LRU] = defaultdict(LRU) config_lock = threading.Lock() chat_lock = threading.Lock() SearchType = utils_config.SearchType telemetry: List[Dict[str, str]] = [] -previous_query: str = None demo: bool = False khoj_version: str = None +anonymous_mode: bool = False if torch.cuda.is_available(): # Use CUDA GPU diff --git a/tests/conftest.py b/tests/conftest.py index 4f7dfb10..5f515ef1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,19 @@ # External Packages import os -from copy import deepcopy from fastapi.testclient import TestClient from pathlib import Path import pytest from fastapi.staticfiles import StaticFiles +from fastapi import FastAPI +import factory +import os +from fastapi import FastAPI + +app = FastAPI() + # Internal Packages -from app.main import app from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware -from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.search_type import image_search, text_search from khoj.utils.config import SearchModels @@ -22,8 +26,6 @@ from khoj.utils.rawconfig import ( OpenAIProcessorConfig, ProcessorConfig, TextContentConfig, - GithubContentConfig, - GithubRepoConfig, ImageContentConfig, SearchConfig, TextSearchConfig, @@ -31,11 +33,31 @@ from khoj.utils.rawconfig import ( ) from khoj.utils import state, fs_syncer from khoj.routers.indexer import configure_content -from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl -from khoj.search_filter.date_filter import DateFilter -from khoj.search_filter.word_filter import WordFilter -from khoj.search_filter.file_filter import FileFilter +from database.models import ( + LocalOrgConfig, + LocalMarkdownConfig, + LocalPlaintextConfig, + LocalPdfConfig, + GithubConfig, + KhojUser, + GithubRepoConfig, +) + + +@pytest.fixture(autouse=True) +def enable_db_access_for_all_tests(db): + pass + + +class UserFactory(factory.django.DjangoModelFactory): + class Meta: + model = KhojUser + + username = factory.Faker("name") + email = factory.Faker("email") + password = factory.Faker("password") + uuid = factory.Faker("uuid4") @pytest.fixture(scope="session") @@ -67,17 +89,28 @@ def search_config() -> SearchConfig: return search_config +@pytest.mark.django_db 
+@pytest.fixture +def default_user(): + return UserFactory() + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() - search_models.text_search = text_search.initialize_model(search_config.asymmetric) search_models.image_search = image_search.initialize_model(search_config.image) return search_models -@pytest.fixture(scope="session") -def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig): +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest.mark.django_db +@pytest.fixture(scope="function") +def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser): content_dir = tmp_path_factory.mktemp("content") # Generate Image Embeddings from Test Images @@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config: image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False) - # Generate Notes Embeddings from Test Notes - content_config.org = TextContentConfig( + LocalOrgConfig.objects.create( input_files=None, input_filter=["tests/data/org/*.org"], - compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"), - embeddings_file=content_dir.joinpath("note_embeddings.pt"), + index_heading_entries=False, + user=default_user, ) - filters = [DateFilter(), WordFilter(), FileFilter()] - text_search.setup( - OrgToJsonl, - get_sample_data("org"), - content_config.org, - search_models.text_search.bi_encoder, - regenerate=False, - filters=filters, - ) - - content_config.plugins = { - "plugin1": TextContentConfig( - input_files=[content_dir.joinpath("notes.jsonl.gz")], - input_filter=None, - compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"), - embeddings_file=content_dir.joinpath("plugin_embeddings.pt"), - ) - } + text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user) if os.getenv("GITHUB_PAT_TOKEN"): - content_config.github = GithubContentConfig( - pat_token=os.getenv("GITHUB_PAT_TOKEN", ""), - repos=[ - GithubRepoConfig( - owner="khoj-ai", - name="lantern", - branch="master", - ) - ], - compressed_jsonl=content_dir.joinpath("github.jsonl.gz"), - embeddings_file=content_dir.joinpath("github_embeddings.pt"), + GithubConfig.objects.create( + pat_token=os.getenv("GITHUB_PAT_TOKEN"), + user=default_user, ) - content_config.plaintext = TextContentConfig( + GithubRepoConfig.objects.create( + owner="khoj-ai", + name="lantern", + branch="master", + github_config=GithubConfig.objects.get(user=default_user), + ) + + LocalPlaintextConfig.objects.create( input_files=None, input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"], - compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"), - embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"), - ) - - content_config.github = GithubContentConfig( - pat_token=os.getenv("GITHUB_PAT_TOKEN", ""), - repos=[ - GithubRepoConfig( - owner="khoj-ai", - name="lantern", - branch="master", - ) - ], - compressed_jsonl=content_dir.joinpath("github.jsonl.gz"), - embeddings_file=content_dir.joinpath("github_embeddings.pt"), - ) - - filters = [DateFilter(), WordFilter(), FileFilter()] - text_search.setup( - JsonlToJsonl, - None, - content_config.plugins["plugin1"], - search_models.text_search.bi_encoder, - regenerate=False, - filters=filters, + user=default_user, ) return content_config @pytest.fixture(scope="session") -def md_content_config(tmp_path_factory): - 
content_dir = tmp_path_factory.mktemp("content") - - # Generate Embeddings for Markdown Content - content_config = ContentConfig() - content_config.markdown = TextContentConfig( +def md_content_config(): + markdown_config = LocalMarkdownConfig.objects.create( input_files=None, input_filter=["tests/data/markdown/*.markdown"], - compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"), - embeddings_file=content_dir.joinpath("markdown_embeddings.pt"), ) - return content_config + return markdown_config @pytest.fixture(scope="session") @@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory): @pytest.fixture(scope="session") def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig): # Initialize app state - state.config.content_type = md_content_config state.config.search_type = search_config state.SearchType = configure_search_types(state.config) # Index Markdown Content for Search - state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) - all_files = fs_syncer.collect_files(state.config.content_type) + all_files = fs_syncer.collect_files() state.content_index = configure_content( state.content_index, state.config.content_type, all_files, state.search_models ) # Initialize Processor from Config state.processor_config = configure_processor(processor_config) + state.anonymous_mode = True + + app = FastAPI() configure_routes(app) configure_middleware(app) @@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p @pytest.fixture(scope="function") -def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig): +def fastapi_app(): + app = FastAPI() + configure_routes(app) + configure_middleware(app) + app.mount("/static", StaticFiles(directory=web_directory), name="static") + return app + + +@pytest.fixture(scope="function") +def client( + content_config: ContentConfig, + search_config: SearchConfig, + processor_config: ProcessorConfig, + default_user: KhojUser, +): state.config.content_type = content_config state.config.search_type = search_config state.SearchType = configure_search_types(state.config) # These lines help us Mock the Search models for these search types - state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) state.search_models.image_search = image_search.initialize_model(search_config.image) - state.content_index.org = text_search.setup( + text_search.setup( OrgToJsonl, get_sample_data("org"), - content_config.org, - state.search_models.text_search.bi_encoder, regenerate=False, + user=default_user, ) state.content_index.image = image_search.setup( content_config.image, state.search_models.image_search, regenerate=False ) - state.content_index.plaintext = text_search.setup( + text_search.setup( PlaintextToJsonl, get_sample_data("plaintext"), - content_config.plaintext, - state.search_models.text_search.bi_encoder, regenerate=False, + user=default_user, ) state.processor_config = configure_processor(processor_config) + state.anonymous_mode = True configure_routes(app) configure_middleware(app) @@ -288,7 +285,6 @@ def client_offline_chat( state.SearchType = configure_search_types(state.config) # Index Markdown Content for Search - state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) state.search_models.image_search = image_search.initialize_model(search_config.image) all_files = 
fs_syncer.collect_files(state.config.content_type) @@ -298,6 +294,7 @@ def client_offline_chat( # Initialize Processor from Config state.processor_config = configure_processor(processor_config_offline_chat) + state.anonymous_mode = True configure_routes(app) configure_middleware(app) @@ -306,9 +303,11 @@ def client_offline_chat( @pytest.fixture(scope="function") -def new_org_file(content_config: ContentConfig): +def new_org_file(default_user: KhojUser, content_config: ContentConfig): # Setup - new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org" + org_config = LocalOrgConfig.objects.filter(user=default_user).first() + input_filters = org_config.input_filter + new_org_file = Path(input_filters[0]).parent / "new_file.org" new_org_file.touch() yield new_org_file @@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig): @pytest.fixture(scope="function") -def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path): - new_org_config = deepcopy(content_config.org) - new_org_config.input_files = [f"{new_org_file}"] - new_org_config.input_filter = None - return new_org_config +def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser): + LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None) + return LocalOrgConfig.objects.filter(user=default_user).first() @pytest.fixture(scope="function") diff --git a/tests/data/config.yml b/tests/data/config.yml index 06978cf1..c544eebe 100644 --- a/tests/data/config.yml +++ b/tests/data/config.yml @@ -9,17 +9,6 @@ content-type: input-filter: - '*.org' - ~/notes/*.org - plugins: - content_plugin_1: - compressed-jsonl: content_plugin_1.jsonl.gz - embeddings-file: content_plugin_1_embeddings.pt - input-files: - - content_plugin_1_new.jsonl.gz - content_plugin_2: - compressed-jsonl: content_plugin_2.jsonl.gz - embeddings-file: content_plugin_2_embeddings.pt - input-filter: - - '*2_new.jsonl.gz' enable-offline-chat: false search-type: asymmetric: diff --git a/tests/test_cli.py b/tests/test_cli.py index 9de3a853..cff2a7f3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -48,14 +48,3 @@ def test_cli_config_from_file(): Path("~/first_from_config.org"), Path("~/second_from_config.org"), ] - assert len(actual_args.config.content_type.plugins.keys()) == 2 - assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [ - Path("content_plugin_1_new.jsonl.gz") - ] - assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"] - assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path( - "content_plugin_1.jsonl.gz" - ) - assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path( - "content_plugin_2_embeddings.pt" - ) diff --git a/tests/test_client.py b/tests/test_client.py index a5f14882..f63b968c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,22 +2,21 @@ from io import BytesIO from PIL import Image from urllib.parse import quote - +import pytest # External Packages from fastapi.testclient import TestClient -import pytest +from fastapi import FastAPI # Internal Packages -from app.main import app from khoj.configure import configure_routes, configure_search_types from khoj.utils import state from khoj.utils.state import search_models, content_index, config from khoj.search_type import text_search, image_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl 
import OrgToJsonl -from khoj.search_filter.word_filter import WordFilter -from khoj.search_filter.file_filter import FileFilter +from database.models import KhojUser +from database.adapters import EmbeddingsAdapters # Test @@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client): # ---------------------------------------------------------------------------------------------------- def test_search_with_valid_content_type(client): - for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]: + for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]: # Act response = client.get(f"/api/search?q=random&t={content_type}") # Assert @@ -75,7 +74,7 @@ def test_index_update(client): # ---------------------------------------------------------------------------------------------------- def test_regenerate_with_valid_content_type(client): - for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]: + for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]: # Arrange files = get_sample_files_data() headers = {"x-api-key": "secret"} @@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db @pytest.mark.skip(reason="Flaky test on parallel test runs") -def test_get_configured_types_via_api(client): +def test_get_configured_types_via_api(client, sample_org_data): # Act - response = client.get(f"/api/config/types") + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) + + enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True) # Assert - assert response.status_code == 200 - assert response.json() == ["all", "org", "image", "plaintext", "plugin1"] + assert list(enabled_types) == ["org"] # ---------------------------------------------------------------------------------------------------- -def test_get_configured_types_with_only_plugin_content_config(content_config): +@pytest.mark.django_db(transaction=True) +def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data): # Arrange - config.content_type = ContentConfig() - config.content_type.plugins = content_config.plugins - state.SearchType = configure_search_types(config) - - configure_routes(app) - client = TestClient(app) + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) # Act response = client.get(f"/api/config/types") # Assert assert response.status_code == 200 - assert response.json() == ["all", "plugin1"] + assert response.json() == ["all", "org", "markdown", "image"] # ---------------------------------------------------------------------------------------------------- -def test_get_configured_types_with_no_plugin_content_config(content_config): +@pytest.mark.django_db(transaction=True) +def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): # Arrange - config.content_type = content_config - config.content_type.plugins = None state.SearchType = configure_search_types(config) + original_config = state.config.content_type + state.config.content_type = None - configure_routes(app) - client = TestClient(app) - - # Act - response = client.get(f"/api/config/types") - - # Assert - assert response.status_code == 200 - assert "plugin1" not in response.json() - - -# ---------------------------------------------------------------------------------------------------- -def 
test_get_configured_types_with_no_content_config(): - # Arrange - config.content_type = ContentConfig() - state.SearchType = configure_search_types(config) - - configure_routes(app) - client = TestClient(app) + configure_routes(fastapi_app) + client = TestClient(fastapi_app) # Act response = client.get(f"/api/config/types") @@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config(): assert response.status_code == 200 assert response.json() == ["all"] + # Restore + state.config.content_type = original_config + # ---------------------------------------------------------------------------------------------------- def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): @@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- -def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data): +@pytest.mark.django_db(transaction=True) +def test_notes_search(client, search_config: SearchConfig, sample_org_data): # Arrange - search_models.text_search = text_search.initialize_model(search_config.asymmetric) - content_index.org = text_search.setup( - OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False - ) + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) user_query = quote("How to git install application?") # Act @@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_notes_search_with_only_filters( client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data ): # Arrange - filters = [WordFilter(), FileFilter()] - search_models.text_search = text_search.initialize_model(search_config.asymmetric) - content_index.org = text_search.setup( + text_search.setup( OrgToJsonl, sample_org_data, - content_config.org, - search_models.text_search.bi_encoder, regenerate=False, - filters=filters, ) user_query = quote('+"Emacs" file:"*.org"') @@ -238,15 +216,10 @@ def test_notes_search_with_only_filters( # ---------------------------------------------------------------------------------------------------- -def test_notes_search_with_include_filter( - client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data -): +@pytest.mark.django_db(transaction=True) +def test_notes_search_with_include_filter(client, sample_org_data): # Arrange - filters = [WordFilter()] - search_models.text_search = text_search.initialize_model(search_config.asymmetric) - content_index.org = text_search.setup( - OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters - ) + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False) user_query = quote('How to git install application? 
+"Emacs"') # Act @@ -260,19 +233,13 @@ def test_notes_search_with_include_filter( # ---------------------------------------------------------------------------------------------------- -def test_notes_search_with_exclude_filter( - client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data -): +@pytest.mark.django_db(transaction=True) +def test_notes_search_with_exclude_filter(client, sample_org_data): # Arrange - filters = [WordFilter()] - search_models.text_search = text_search.initialize_model(search_config.asymmetric) - content_index.org = text_search.setup( + text_search.setup( OrgToJsonl, sample_org_data, - content_config.org, - search_models.text_search.bi_encoder, regenerate=False, - filters=filters, ) user_query = quote('How to git install application? -"clone"') @@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter( assert "clone" not in search_result +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser): + # Arrange + text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user) + user_query = quote("How to git install application?") + + # Act + response = client.get(f"/api/search?q={user_query}&n=1&t=org") + + # Assert + assert response.status_code == 200 + # assert actual response has no data as the default_user is different from the user making the query (anonymous) + assert len(response.json()) == 0 + + def get_sample_files_data(): return { "files": ("path/to/filename.org", "* practicing piano", "text/org"), diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index d0d05bc5..f1f26d28 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -1,53 +1,12 @@ # Standard Packages import re from datetime import datetime -from math import inf # External Packages import pytest # Internal Packages from khoj.search_filter.date_filter import DateFilter -from khoj.utils.rawconfig import Entry - - -@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.") -def test_date_filter(): - entries = [ - Entry(compiled="Entry with no date", raw="Entry with no date"), - Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"), - Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"), - ] - - q_with_no_date_filter = "head tail" - ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) - assert ret_query == "head tail" - assert entry_indices == {0, 1, 2} - - q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) - assert ret_query == "head tail" - assert entry_indices == set() - - query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == "head tail" - assert entry_indices == {2} - - query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == "head tail" - assert entry_indices == {1} - - query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - 
assert ret_query == "head tail" - assert entry_indices == {2} - - query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) - assert ret_query == "head tail" - assert entry_indices == {1, 2} @pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.") @@ -56,8 +15,8 @@ def test_extract_date_range(): datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp(), ] - assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()] - assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf] + assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()] + assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None] assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [ datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp(), diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index ed632d32..f5a903f8 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry def test_no_file_filter(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = "head tail" # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == False - assert ret_query == "head tail" - assert entry_indices == {0, 1, 2, 3} def test_file_filter_with_non_existent_file(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = 'head file:"nonexistent.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True - assert ret_query == "head tail" - assert entry_indices == {} def test_single_file_filter(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = 'head file:"file 1.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True - assert ret_query == "head tail" - assert entry_indices == {0, 2} def test_file_filter_with_partial_match(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = 'head file:"1.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True - assert ret_query == "head tail" - assert entry_indices == {0, 2} def test_file_filter_with_regex_match(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = 'head file:"*.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True - assert ret_query == "head tail" - assert entry_indices == {0, 1, 2, 3} def test_multiple_file_filter(): # Arrange file_filter = FileFilter() - entries = arrange_content() q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices 
= file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True - assert ret_query == "head tail" - assert entry_indices == {0, 1, 2, 3} def test_get_file_filter_terms(): @@ -108,7 +84,7 @@ def test_get_file_filter_terms(): filter_terms = file_filter.get_filter_terms(q_with_filter_terms) # Assert - assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"'] + assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"] def arrange_content(): diff --git a/tests/test_jsonl_to_jsonl.py b/tests/test_jsonl_to_jsonl.py deleted file mode 100644 index b52b5fc9..00000000 --- a/tests/test_jsonl_to_jsonl.py +++ /dev/null @@ -1,78 +0,0 @@ -# Internal Packages -from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl -from khoj.utils.rawconfig import Entry - - -def test_process_entries_from_single_input_jsonl(tmp_path): - "Convert multiple jsonl entries from single file to entries." - # Arrange - input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"} -{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"} -""" - input_jsonl_file = create_file(tmp_path, input_jsonl) - - # Act - # Process Each Entry from All Notes Files - input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file]) - entries = list(map(Entry.from_dict, input_jsons)) - output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries) - - # Assert - assert len(entries) == 2 - assert output_jsonl == input_jsonl - - -def test_process_entries_from_multiple_input_jsonls(tmp_path): - "Convert multiple jsonl entries from single file to entries." - # Arrange - input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}""" - input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}""" - input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl") - input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl") - - # Act - # Process Each Entry from All Notes Files - input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2]) - entries = list(map(Entry.from_dict, input_jsons)) - output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries) - - # Assert - assert len(entries) == 2 - assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n" - - -def test_get_jsonl_files(tmp_path): - "Ensure JSONL files specified via input-filter, input-files extracted" - # Arrange - # Include via input-filter globs - group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl") - group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl") - group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl") - group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl") - # Include via input-file field - file1 = create_file(tmp_path, filename="notes.jsonl") - # Not included by any filter - create_file(tmp_path, filename="not-included-jsonl.jsonl") - create_file(tmp_path, filename="not-included-text.txt") - - expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1])) - - # Setup input-files, input-filters - input_files = [tmp_path / "notes.jsonl"] - input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"] - - # Act - extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter) - - # Assert - 
-    assert len(extracted_org_files) == 5
-    assert extracted_org_files == expected_files
-
-
-# Helper Functions
-def create_file(tmp_path, entry=None, filename="test.jsonl"):
-    jsonl_file = tmp_path / filename
-    jsonl_file.touch()
-    if entry:
-        jsonl_file.write_text(entry)
-    return jsonl_file
diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py
index abf20d09..d47c212e 100644
--- a/tests/test_org_to_jsonl.py
+++ b/tests/test_org_to_jsonl.py
@@ -4,7 +4,7 @@ import os

 # Internal Packages
 from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
 from khoj.utils.helpers import is_none_or_empty
 from khoj.utils.rawconfig import Entry
 from khoj.utils.fs_syncer import get_org_files
@@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):

     # Split each entry from specified Org files by max words
     jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
-        TextToJsonl.split_entries_by_max_tokens(
+        TextEmbeddings.split_entries_by_max_tokens(
             OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
         )
     )
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():

     # Act
     # Split entry by max words and drop words larger than max word length
-    processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
+    processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]

     # Assert
     # "Heading" dropped from compiled version because its over the set max word limit
diff --git a/tests/test_plaintext_to_jsonl.py b/tests/test_plaintext_to_jsonl.py
index a6da30e1..56c68e38 100644
--- a/tests/test_plaintext_to_jsonl.py
+++ b/tests/test_plaintext_to_jsonl.py
@@ -7,6 +7,7 @@ from pathlib import Path
 from khoj.utils.fs_syncer import get_plaintext_files
 from khoj.utils.rawconfig import TextContentConfig
 from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
+from database.models import LocalPlaintextConfig, KhojUser


 def test_plaintext_file(tmp_path):
@@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
     assert set(extracted_plaintext_files.keys()) == set(expected_files)


-def test_parse_html_plaintext_file(content_config):
+def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
     "Ensure HTML files are parsed correctly"
     # Arrange
     # Setup input-files, input-filters
-    extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
+    config = LocalPlaintextConfig.objects.filter(user=default_user).first()
+    extracted_plaintext_files = get_plaintext_files(config=config)

     # Act
     maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 179718fa..af47ffe5 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -3,23 +3,30 @@ import logging
 import locale
 from pathlib import Path
 import os
+import asyncio

 # External Packages
 import pytest

 # Internal Packages
-from khoj.utils.state import content_index, search_models
 from khoj.search_type import text_search
+from khoj.utils.rawconfig import ContentConfig, SearchConfig
 from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
 from khoj.processor.github.github_to_jsonl import GithubToJsonl
 from khoj.utils.config import SearchModels
-from khoj.utils.fs_syncer import get_org_files
+from khoj.utils.fs_syncer import get_org_files, collect_files
+from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
+
+logger = logging.getLogger(__name__)
 from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig


 # Test
 # ----------------------------------------------------------------------------------------------------
-def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
+@pytest.mark.django_db
+def test_text_search_setup_with_missing_file_raises_error(
+    org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
+):
     # Arrange
     # Ensure file mentioned in org.input-files is missing
     single_new_file = Path(org_config_with_only_new_file.input_files[0])
@@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n


 # ----------------------------------------------------------------------------------------------------
-def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
+@pytest.mark.django_db
+def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
     # Arrange
     orgfile = tmp_path / "directory.org" / "file.org"
     orgfile.parent.mkdir()
     with open(orgfile, "w") as f:
         f.write("* Heading\n- List item\n")
-    org_content_config = TextContentConfig(
-        input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
+
+    LocalOrgConfig.objects.create(
+        input_filter=[f"{tmp_path}/**/*"],
+        input_files=None,
+        user=default_user,
     )

+    org_files = collect_files(user=default_user)["org"]
+
     # Act
     # should not raise IsADirectoryError and return orgfile
-    assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
+    assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 def test_text_search_setup_with_empty_file_raises_error(
-    org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
 ):
     # Arrange
     data = get_org_files(org_config_with_only_new_file)

     # Act
     # Generate notes embeddings during asymmetric setup
-    with pytest.raises(ValueError, match=r"^No valid entries found*"):
-        text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+
+    assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    verify_embeddings(0, default_user)


 # ----------------------------------------------------------------------------------------------------
-def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_text_search_setup(content_config, default_user: KhojUser, caplog):
     # Arrange
-    data = get_org_files(content_config.org)
-
-    # Act
-    # Regenerate notes embeddings during asymmetric setup
-    notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
-    )
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # Assert
-    assert len(notes_model.entries) == 10
-    assert len(notes_model.corpus_embeddings) == 10
+    assert "Deleting all embeddings for file type org" in caplog.records[1].message
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message


 # ----------------------------------------------------------------------------------------------------
-def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
+@pytest.mark.django_db
+def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
     # Arrange
-    caplog.set_level(logging.INFO, logger="khoj")
-
-    data = get_org_files(content_config.org)
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)

     # Act
     # Generate initial notes embeddings during asymmetric setup
-    text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
     initial_logs = caplog.text
     caplog.clear()  # Clear logs

     # Run asymmetric setup again with no changes to data source. Ensure index is not updated
-    text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
     final_logs = caplog.text

     # Assert
-    assert "Creating index from scratch." in initial_logs
-    assert "Creating index from scratch." not in final_logs
+    assert "Deleting all embeddings for file type org" in initial_logs
+    assert "Deleting all embeddings for file type org" not in final_logs


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 @pytest.mark.anyio
-async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
+# @pytest.mark.asyncio
+async def test_text_search(search_config: SearchConfig):
     # Arrange
-    data = get_org_files(content_config.org)
-
-    search_models.text_search = text_search.initialize_model(search_config.asymmetric)
-    content_index.org = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
+    default_user = await KhojUser.objects.acreate(
+        username="test_user", password="test_password", email="test@example.com"
     )
+    # Arrange
+    org_config = await LocalOrgConfig.objects.acreate(
+        input_files=None,
+        input_filter=["tests/data/org/*.org"],
+        index_heading_entries=False,
+        user=default_user,
+    )
+    data = get_org_files(org_config)
+
+    loop = asyncio.get_event_loop()
+    await loop.run_in_executor(
+        None,
+        text_search.setup,
+        OrgToJsonl,
+        data,
+        True,
+        True,
+        default_user,
+    )
+
     query = "How to git install application?"

     # Act
-    hits, entries = await text_search.query(
-        query, search_model=search_models.text_search, content=content_index.org, rank_results=True
-    )
-
-    results = text_search.collate_results(hits, entries, count=1)
+    hits = await text_search.query(default_user, query)

     # Assert
+    results = text_search.collate_results(hits)
+    results = sorted(results, key=lambda x: float(x.score))[:1]
     # search results should contain "git clone" entry
     search_result = results[0].entry
     assert "git clone" in search_result


 # ----------------------------------------------------------------------------------------------------
-def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
     max_tokens = 256
@@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent

     # Act
     # reload embeddings, entries, notes model after adding new org-mode file
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
     # verify newly added org-mode entry is split by max tokens
-    assert len(initial_notes_model.entries) == 2
-    assert len(initial_notes_model.corpus_embeddings) == 2
+    record = caplog.records[1]
+    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message


 # ----------------------------------------------------------------------------------------------------
 # @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
+@pytest.mark.django_db
 def test_entry_chunking_by_max_tokens_not_full_corpus(
-    org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
 ):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
     data = {
         "readme.org": """
 * Khoj
- /Allow natural language search on user content like notes, images using transformer based models/
+/Allow natural language search on user content like notes, images using transformer based models/

- All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
+All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline

 ** Dependencies
- - Python3
- - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
+- Python3
+- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]

 ** Install
- #+begin_src shell
- git clone https://github.com/khoj-ai/khoj && cd khoj
- conda env create -f environment.yml
- conda activate khoj
- #+end_src"""
+#+begin_src shell
+git clone https://github.com/khoj-ai/khoj && cd khoj
+conda env create -f environment.yml
+conda activate khoj
+#+end_src"""
     }

     text_search.setup(
         OrgToJsonl,
         data,
-        org_config_with_only_new_file,
-        search_models.text_search.bi_encoder,
         regenerate=False,
+        user=default_user,
     )

     max_tokens = 256
@@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(

     # Act
     # reload embeddings, entries, notes model after adding new org-mode file
-    initial_notes_model = text_search.setup(
-        OrgToJsonl,
-        data,
-        org_config_with_only_new_file,
-        search_models.text_search.bi_encoder,
-        regenerate=False,
-        full_corpus=False,
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(
+            OrgToJsonl,
+            data,
+            regenerate=False,
+            full_corpus=False,
+            user=default_user,
+        )
+
+    record = caplog.records[1]

     # Assert
     # verify newly added org-mode entry is split by max tokens
-    assert len(initial_notes_model.entries) == 5
-    assert len(initial_notes_model.corpus_embeddings) == 5
+    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 def test_regenerate_index_with_new_entry(
-    content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
+    content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
 ):
     # Arrange
-    data = get_org_files(content_config.org)
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
-    )
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)

-    assert len(initial_notes_model.entries) == 10
-    assert len(initial_notes_model.corpus_embeddings) == 10
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message

     # append org-mode entry to first org input file in config
-    content_config.org.input_files = [f"{new_org_file}"]
+    org_config.input_files = [f"{new_org_file}"]
     with open(new_org_file, "w") as f:
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

-    data = get_org_files(content_config.org)
+    data = get_org_files(org_config)

     # Act
     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # Assert
-    assert len(regenerated_notes_model.entries) == 11
-    assert len(regenerated_notes_model.corpus_embeddings) == 11
-
-    # verify new entry appended to index, without disrupting order or content of existing entries
-    error_details = compare_index(initial_notes_model, regenerated_notes_model)
-    if error_details:
-        pytest.fail(error_details, False)
-
-    # Cleanup
-    # reset input_files in config to empty list
-    content_config.org.input_files = []
+    assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
+    verify_embeddings(11, default_user)


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 def test_update_index_with_duplicate_entries_in_stable_order(
-    org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
 ):
     # Arrange
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(

     # Act
     # load embeddings, entries, notes model after adding new org-mode file
-    initial_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     data = get_org_files(org_config_with_only_new_file)

     # update embeddings, entries, notes model after adding new org-mode file
-    updated_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert len(initial_index.entries) == len(updated_index.entries) == 1
-    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
+    assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message

-    # verify the same entry is added even when there are multiple duplicate entries
-    error_details = compare_index(initial_index, updated_index)
-    if error_details:
-        pytest.fail(error_details)
+    verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
-def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
     data = get_org_files(org_config_with_only_new_file)

     # load embeddings, entries, notes model after adding new org file with 2 entries
-    initial_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # update embeddings, entries, notes model after removing an entry from the org file
     with open(new_file_to_index, "w") as f:
@@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
     data = get_org_files(org_config_with_only_new_file)

     # Act
-    updated_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert len(initial_index.entries) == len(updated_index.entries) + 1
-    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
+    assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message

-    # verify the same entry is added even when there are multiple duplicate entries
-    error_details = compare_index(updated_index, initial_index)
-    if error_details:
-        pytest.fail(error_details)
+    verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
-def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+@pytest.mark.django_db
+def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
     # Arrange
-    data = get_org_files(content_config.org)
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
-    )
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # append org-mode entry to first org input file in config
     with open(new_org_file, "w") as f:
         new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
         f.write(new_entry)

-    data = get_org_files(content_config.org)
+    data = get_org_files(org_config)

     # Act
     # update embeddings, entries with the newly added note
-    content_config.org.input_files = [f"{new_org_file}"]
-    final_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
-    assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
-    assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message

-    # verify new entry appended to index, without disrupting order or content of existing entries
-    error_details = compare_index(initial_notes_model, final_notes_model)
-    if error_details:
-        pytest.fail(error_details, False)
-
-    # Cleanup
-    # reset input_files in config to empty list
-    content_config.org.input_files = []
+    verify_embeddings(11, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
-def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
+def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
+    # Arrange
+    github_config = GithubConfig.objects.filter(user=default_user).first()
+
     # Act
     # Regenerate github embeddings to test asymmetric setup without caching
-    github_model = text_search.setup(
-        GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
+    text_search.setup(
+        GithubToJsonl,
+        {},
+        regenerate=True,
+        user=default_user,
+        config=github_config,
     )

     # Assert
-    assert len(github_model.entries) > 1
+    embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
+    assert embeddings > 1


-def compare_index(initial_notes_model, final_notes_model):
-    mismatched_entries, mismatched_embeddings = [], []
-    for index in range(len(initial_notes_model.entries)):
-        if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
-            mismatched_entries.append(index)
-
-    # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
-    for index in range(len(initial_notes_model.corpus_embeddings)):
-        if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
-            mismatched_embeddings.append(index)
-
-    error_details = ""
-    if mismatched_entries:
-        mismatched_entries_str = ",".join(map(str, mismatched_entries))
-        error_details += f"Entries at {mismatched_entries_str} not equal\n"
-    if mismatched_embeddings:
-        mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
-        error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
-
-    return error_details
+def verify_embeddings(expected_count, user):
+    embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
+    assert embeddings == expected_count
diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py
index 04f45506..2ede35e7 100644
--- a/tests/test_word_filter.py
+++ b/tests/test_word_filter.py
@@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
 from khoj.utils.rawconfig import Entry


-def test_no_word_filter():
-    # Arrange
-    word_filter = WordFilter()
-    entries = arrange_content()
-    q_with_no_filter = "head tail"
-
-    # Act
-    can_filter = word_filter.can_filter(q_with_no_filter)
-    ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
-
-    # Assert
-    assert can_filter == False
-    assert ret_query == "head tail"
-    assert entry_indices == {0, 1, 2, 3}
-
-
 def test_word_exclude_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     q_with_exclude_filter = 'head -"exclude_word" tail'

     # Act
     can_filter = word_filter.can_filter(q_with_exclude_filter)
-    ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {0, 2}


 def test_word_include_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     query_with_include_filter = 'head +"include_word" tail'

     # Act
     can_filter = word_filter.can_filter(query_with_include_filter)
-    ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {2, 3}


 def test_word_include_and_exclude_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'

     # Act
     can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
-    ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {2}


 def test_get_word_filter_terms():