[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)

- Partition configuration for indexing local data based on user accounts
- Store indexed data in an underlying Postgres database using the `pgvector` extension
- Add migrations for all relevant user data and embeddings tables; lookup time has seen very little performance optimization so far
- Apply search filters using SQL queries (see the sketch below)
- Start removing many server-level configuration settings
- Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB.
- Update the Docker image and docker-compose.yml to work with the new application design
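
For illustration, here is a minimal sketch of the per-user lookup this enables. It mirrors the `EmbeddingsAdapters.search_with_embeddings` adapter added in `database/adapters.py`; the standalone function and its name below are illustrative, not part of the commit:

```python
# Sketch: per-user semantic search over the new Embeddings table via pgvector.
# Mirrors EmbeddingsAdapters.search_with_embeddings in database/adapters.py.
from pgvector.django import CosineDistance

from database.models import Embeddings, KhojUser


def search_user_embeddings(user: KhojUser, query_embedding, max_results: int = 10):
    # Restrict rows to the requesting user, then rank by cosine distance to the query vector.
    return (
        Embeddings.objects.filter(user=user)
        .annotate(distance=CosineDistance("embeddings", query_embedding))
        .order_by("distance")[:max_results]
    )
```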
Authored by sabaimran on 2023-10-26 09:42:29 -07:00; committed by GitHub.
parent 963cd165eb
commit 216acf545f
60 changed files with 1827 additions and 1792 deletions

.github/workflows/pre-commit.yml (new file, 48 lines)

@@ -0,0 +1,48 @@
name: pre-commit

on:
  pull_request:
    paths:
      - src/**
      - tests/**
      - config/**
      - pyproject.toml
      - .pre-commit-config.yml
      - .github/workflows/test.yml
  push:
    branches:
      - master
    paths:
      - src/khoj/**
      - tests/**
      - config/**
      - pyproject.toml
      - .pre-commit-config.yml
      - .github/workflows/test.yml

jobs:
  test:
    name: Run Tests
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: ⏬️ Install Dependencies
        run: |
          sudo apt update && sudo apt install -y libegl1
          python -m pip install --upgrade pip
      - name: ⬇️ Install Application
        run: pip install --upgrade .[dev]
      - name: 🌡️ Validate Application
        run: pre-commit run --hook-stage manual --all

.github/workflows/test.yml

@@ -2,10 +2,8 @@ name: test
 on:
   pull_request:
-    branches:
-      - 'master'
     paths:
-      - src/khoj/**
+      - src/**
       - tests/**
       - config/**
       - pyproject.toml
@@ -13,7 +11,7 @@ on:
       - .github/workflows/test.yml
   push:
     branches:
-      - 'master'
+      - master
     paths:
       - src/khoj/**
       - tests/**
@@ -26,6 +24,7 @@ jobs:
   test:
     name: Run Tests
     runs-on: ubuntu-latest
+    container: ubuntu:jammy
     strategy:
       fail-fast: false
       matrix:
@@ -33,6 +32,17 @@
         - '3.9'
         - '3.10'
         - '3.11'
+    services:
+      postgres:
+        image: ankane/pgvector
+        env:
+          POSTGRES_PASSWORD: postgres
+          POSTGRES_USER: postgres
+        ports:
+          - 5432:5432
+        options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
     steps:
       - uses: actions/checkout@v3
         with:
@@ -43,17 +53,37 @@
         with:
           python-version: ${{ matrix.python_version }}
-      - name: ⏬️ Install Dependencies
+      - name: Install Git
         run: |
-          sudo apt update && sudo apt install -y libegl1
+          apt update && apt install -y git
+      - name: ⏬️ Install Dependencies
+        env:
+          DEBIAN_FRONTEND: noninteractive
+        run: |
+          apt update && apt install -y libegl1 sqlite3 libsqlite3-dev libsqlite3-0
+      - name: ⬇️ Install Postgres
+        env:
+          DEBIAN_FRONTEND: noninteractive
+        run : |
+          apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
+      - name: ⬇️ Install pip
+        run: |
+          apt install -y python3-pip
+          python -m ensurepip --upgrade
           python -m pip install --upgrade pip
       - name: ⬇️ Install Application
-        run: pip install --upgrade .[dev]
+        run: sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && pip install --upgrade .[dev]
-      - name: 🌡️ Validate Application
-        run: pre-commit run --hook-stage manual --all
       - name: 🧪 Test Application
+        env:
+          POSTGRES_HOST: postgres
+          POSTGRES_PORT: 5432
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: postgres
+          POSTGRES_DB: postgres
         run: pytest
         timeout-minutes: 10

Dockerfile

@@ -8,13 +8,20 @@ RUN apt update -y && apt -y install python3-pip git
 WORKDIR /app
 
 # Install Application
-COPY . .
+COPY pyproject.toml .
+COPY README.md .
 RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
     pip install --no-cache-dir .
 
+# Copy Source Code
+COPY . .
+
+# Set the PYTHONPATH environment variable in order for it to find the Django app.
+ENV PYTHONPATH=/app/src:$PYTHONPATH
+
 # Run the Application
 # There are more arguments required for the application to run,
 # but these should be passed in through the docker-compose.yml file.
 ARG PORT
 EXPOSE ${PORT}
-ENTRYPOINT ["khoj"]
+ENTRYPOINT ["python3", "src/khoj/main.py"]

docker-compose.yml

@@ -1,7 +1,21 @@
 version: "3.9"
 services:
+  database:
+    image: ankane/pgvector
+    ports:
+      - "5432:5432"
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: postgres
+      POSTGRES_DB: postgres
+    volumes:
+      - khoj_db:/var/lib/postgresql/data/
   server:
+    # Use the following line to use the latest version of khoj. Otherwise, it will build from source.
     image: ghcr.io/khoj-ai/khoj:latest
+    # Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the official image.
+    # build:
+    #   context: .
     ports:
       # If changing the local port (left hand side), no other changes required.
       # If changing the remote port (right hand side),
@@ -26,8 +40,15 @@ services:
       - ./tests/data/models/:/root/.khoj/search/
       - khoj_config:/root/.khoj/
     # Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
+    environment:
+      - POSTGRES_DB=postgres
+      - POSTGRES_USER=postgres
+      - POSTGRES_PASSWORD=postgres
+      - POSTGRES_HOST=database
+      - POSTGRES_PORT=5432
     command: --host="0.0.0.0" --port=42110 -vv
 
 volumes:
   khoj_config:
+  khoj_db:

pyproject.toml

@@ -59,13 +59,15 @@ dependencies = [
     "requests >= 2.26.0",
     "bs4 >= 0.0.1",
     "anyio == 3.7.1",
-    "pymupdf >= 1.23.3",
+    "pymupdf >= 1.23.5",
     "django == 4.2.5",
     "authlib == 1.2.1",
     "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
     "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
+    "pgvector == 0.2.3",
+    "psycopg2-binary == 2.9.9",
 ]
 
 dynamic = ["version"]
@@ -91,6 +93,8 @@ dev = [
     "mypy >= 1.0.1",
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
+    "pytest-django == 4.5.2",
+    "pytest-asyncio == 0.21.1",
 ]
 
 [tool.hatch.version]

pytest.ini (new file, 4 lines)

@@ -0,0 +1,4 @@
[pytest]
DJANGO_SETTINGS_MODULE = app.settings
pythonpath = . src
testpaths = tests

src/app/README.md (new file, 60 lines)

@@ -0,0 +1,60 @@
# Django App
Khoj uses Django as the backend framework primarily for its powerful ORM and the admin interface. The Django app is located in the `src/app` directory. We have one installed app, under the `/database/` directory. This app is responsible for all the database related operations and holds all of our models. You can find the extensive Django documentation [here](https://docs.djangoproject.com/en/4.2/) 🌈.
## Setup (Docker)
### Prerequisites
1. Ensure you have [Docker](https://docs.docker.com/get-docker/) installed.
2. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
### Run
Using the `docker-compose.yml` file in the root directory, you can run the Khoj app using the following command:
```bash
docker-compose up
```
## Setup (Local)
### Install dependencies
```bash
pip install -e '.[dev]'
```
### Setup the database
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
```bash
cd /tmp
git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
cd pgvector
make
make install # may need sudo
```
3. Create a database (see the example below).
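
A minimal example, assuming a local Postgres install and the default database name `khoj` from `src/app/settings.py`:
```bash
# Create the development database that Khoj will connect to
createdb khoj
```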
### Make migrations
This command creates the migrations for the `database` app. Run it whenever a model in the `database` app is added, modified, or deleted.
```bash
python3 src/manage.py makemigrations
```
### Run migrations
This command will run any pending migrations in your application.
```bash
python3 src/manage.py migrate
```
### Run the server
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
```bash
python3 src/khoj/main.py
```
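
If your Postgres connection details differ from the defaults in `src/app/settings.py`, export the corresponding environment variables before starting the server (the values below are illustrative):
```bash
export POSTGRES_HOST=localhost
export POSTGRES_PORT=5432
export POSTGRES_USER=postgres
export POSTGRES_PASSWORD=postgres
export POSTGRES_DB=khoj
python3 src/khoj/main.py
```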

src/app/settings.py

@@ -77,8 +77,12 @@ WSGI_APPLICATION = "app.wsgi.application"
 
 DATABASES = {
     "default": {
-        "ENGINE": "django.db.backends.sqlite3",
-        "NAME": BASE_DIR / "db.sqlite3",
+        "ENGINE": "django.db.backends.postgresql",
+        "HOST": os.getenv("POSTGRES_HOST", "localhost"),
+        "PORT": os.getenv("POSTGRES_PORT", "5432"),
+        "USER": os.getenv("POSTGRES_USER", "postgres"),
+        "NAME": os.getenv("POSTGRES_DB", "khoj"),
+        "PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
     }
 }

src/database/adapters.py

@@ -1,15 +1,30 @@
-from typing import Type, TypeVar
+from typing import Type, TypeVar, List
 import uuid
+from datetime import date
 
 from django.db import models
 from django.contrib.sessions.backends.db import SessionStore
+from pgvector.django import CosineDistance
+from django.db.models.manager import BaseManager
+from django.db.models import Q
+from torch import Tensor
 
 # Import sync_to_async from Django Channels
 from asgiref.sync import sync_to_async
 
 from fastapi import HTTPException
 
-from database.models import KhojUser, GoogleUser, NotionConfig
+from database.models import (
+    KhojUser,
+    GoogleUser,
+    NotionConfig,
+    GithubConfig,
+    Embeddings,
+    GithubRepoConfig,
+)
+from khoj.search_filter.word_filter import WordFilter
+from khoj.search_filter.file_filter import FileFilter
+from khoj.search_filter.date_filter import DateFilter
 
 ModelType = TypeVar("ModelType", bound=models.Model)
@@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
 
 async def create_google_user(token: dict) -> KhojUser:
     user_info = token.get("userinfo")
-    user = await KhojUser.objects.acreate(
-        username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
-    )
+    user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
     await user.asave()
     await GoogleUser.objects.acreate(
         sub=user_info.get("sub"),
@@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
     if not user:
         raise HTTPException(status_code=401, detail="Invalid user")
     return user
+
+
+def get_all_users() -> BaseManager[KhojUser]:
+    return KhojUser.objects.all()
+
+
+def get_user_github_config(user: KhojUser):
+    config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
+    if not config:
+        return None
+    return config
+
+
+def get_user_notion_config(user: KhojUser):
+    config = NotionConfig.objects.filter(user=user).first()
+    if not config:
+        return None
+    return config
+
+
+async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
+    deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
+    deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
+    await object.objects.filter(user=user).adelete()
+    await object.objects.acreate(
+        input_files=deduped_files,
+        input_filter=deduped_filters,
+        index_heading_entries=updated_config.index_heading_entries,
+        user=user,
+    )
+
+
+async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
+    config = await GithubConfig.objects.filter(user=user).afirst()
+    if not config:
+        config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
+    else:
+        config.pat_token = pat_token
+        await config.asave()
+        await config.githubrepoconfig.all().adelete()
+
+    for repo in repos:
+        await GithubRepoConfig.objects.acreate(
+            name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config
+        )
+    return config
+
+
+class EmbeddingsAdapters:
+    word_filer = WordFilter()
+    file_filter = FileFilter()
+    date_filter = DateFilter()
+
+    @staticmethod
+    def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
+        return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
+
+    @staticmethod
+    def delete_embedding_by_file(user: KhojUser, file_path: str):
+        deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
+        return deleted_count
+
+    @staticmethod
+    def delete_all_embeddings(user: KhojUser, file_type: str):
+        deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
+        return deleted_count
+
+    @staticmethod
+    def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
+        return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
+
+    @staticmethod
+    def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
+        Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
+
+    @staticmethod
+    def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
+        return embeddings.filter(
+            embeddingsdates__date__gte=start_date,
+            embeddingsdates__date__lte=end_date,
+        )
+
+    @staticmethod
+    async def user_has_embeddings(user: KhojUser):
+        return await Embeddings.objects.filter(user=user).aexists()
+
+    @staticmethod
+    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
+        q_filter_terms = Q()
+
+        explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
+        file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
+        date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
+
+        if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
+            return Embeddings.objects.filter(user=user)
+
+        for term in explicit_word_terms:
+            if term.startswith("+"):
+                q_filter_terms &= Q(raw__icontains=term[1:])
+            elif term.startswith("-"):
+                q_filter_terms &= ~Q(raw__icontains=term[1:])
+
+        q_file_filter_terms = Q()
+
+        if len(file_filters) > 0:
+            for term in file_filters:
+                q_file_filter_terms |= Q(file_path__regex=term)
+            q_filter_terms &= q_file_filter_terms
+
+        if len(date_filters) > 0:
+            min_date, max_date = date_filters
+            if min_date is not None:
+                # Convert the min_date timestamp to yyyy-mm-dd format
+                formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
+                q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
+            if max_date is not None:
+                # Convert the max_date timestamp to yyyy-mm-dd format
+                formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
+                q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
+
+        relevant_embeddings = Embeddings.objects.filter(user=user).filter(
+            q_filter_terms,
+        )
+
+        if file_type_filter:
+            relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
+
+        return relevant_embeddings
+
+    @staticmethod
+    def search_with_embeddings(
+        user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
+    ):
+        relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
+        relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
+            distance=CosineDistance("embeddings", embeddings)
+        )
+        if file_type_filter:
+            relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
+        relevant_embeddings = relevant_embeddings.order_by("distance")
+        return relevant_embeddings[:max_results]
+
+    @staticmethod
+    def get_unique_file_types(user: KhojUser):
+        return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()


@@ -1,79 +0,0 @@
# Generated by Django 4.2.5 on 2023-09-27 17:52
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [
migrations.CreateModel(
name="Configuration",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
],
),
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("pat_token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
),
],
),
migrations.AddField(
model_name="configuration",
name="user",
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
]

src/database/migrations/0003_vector_extension.py

@@ -0,0 +1,10 @@
from django.db import migrations
from pgvector.django import VectorExtension
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [VectorExtension()]

src/database/migrations/0004_conversationprocessorconfig_githubconfig_and_more.py

@@ -0,0 +1,193 @@
# Generated by Django 4.2.5 on 2023-10-11 22:24
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import pgvector.django
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0003_vector_extension"),
]
operations = [
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("pat_token", models.CharField(max_length=200)),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("token", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPlaintextConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPdfConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalOrgConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalMarkdownConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="githubrepoconfig",
to="database.githubconfig",
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="githubconfig",
name="user",
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
migrations.CreateModel(
name="Embeddings",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("embeddings", pgvector.django.VectorField(dimensions=384)),
("raw", models.TextField()),
("compiled", models.TextField()),
("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)),
(
"file_type",
models.CharField(
choices=[
("image", "Image"),
("pdf", "Pdf"),
("plaintext", "Plaintext"),
("markdown", "Markdown"),
("org", "Org"),
("notion", "Notion"),
("github", "Github"),
("conversation", "Conversation"),
],
default="plaintext",
max_length=30,
),
),
("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)),
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
("url", models.URLField(blank=True, default=None, max_length=400, null=True)),
("hashed_value", models.CharField(max_length=100)),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

src/database/migrations/0005_embeddings_corpus_id.py

@@ -0,0 +1,18 @@
# Generated by Django 4.2.5 on 2023-10-13 02:39
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0004_conversationprocessorconfig_githubconfig_and_more"),
]
operations = [
migrations.AddField(
model_name="embeddings",
name="corpus_id",
field=models.UUIDField(default=uuid.uuid4, editable=False),
),
]


@@ -0,0 +1,33 @@
# Generated by Django 4.2.5 on 2023-10-13 19:28
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0005_embeddings_corpus_id"),
]
operations = [
migrations.CreateModel(
name="EmbeddingsDates",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("date", models.DateField()),
(
"embeddings",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="embeddings_dates",
to="database.embeddings",
),
),
],
options={
"indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")],
},
),
]

src/database/models.py

@@ -2,11 +2,25 @@ import uuid
 
 from django.db import models
 from django.contrib.auth.models import AbstractUser
+from pgvector.django import VectorField
 
 
+class BaseModel(models.Model):
+    created_at = models.DateTimeField(auto_now_add=True)
+    updated_at = models.DateTimeField(auto_now=True)
+
+    class Meta:
+        abstract = True
+
+
 class KhojUser(AbstractUser):
     uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
 
+    def save(self, *args, **kwargs):
+        if not self.uuid:
+            self.uuid = uuid.uuid4()
+        super().save(*args, **kwargs)
+
 
 class GoogleUser(models.Model):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
@@ -23,31 +37,85 @@
         return self.name
 
 
-class Configuration(models.Model):
-    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
-
-
-class NotionConfig(models.Model):
+class NotionConfig(BaseModel):
     token = models.CharField(max_length=200)
-    compressed_jsonl = models.CharField(max_length=300)
-    embeddings_file = models.CharField(max_length=300)
-    config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
 
 
-class GithubConfig(models.Model):
+class GithubConfig(BaseModel):
     pat_token = models.CharField(max_length=200)
-    compressed_jsonl = models.CharField(max_length=300)
-    embeddings_file = models.CharField(max_length=300)
-    config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
 
 
-class GithubRepoConfig(models.Model):
+class GithubRepoConfig(BaseModel):
     name = models.CharField(max_length=200)
     owner = models.CharField(max_length=200)
     branch = models.CharField(max_length=200)
-    github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
+    github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
 
 
-class ConversationProcessorConfig(models.Model):
+class LocalOrgConfig(BaseModel):
+    input_files = models.JSONField(default=list, null=True)
+    input_filter = models.JSONField(default=list, null=True)
+    index_heading_entries = models.BooleanField(default=False)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class LocalMarkdownConfig(BaseModel):
+    input_files = models.JSONField(default=list, null=True)
+    input_filter = models.JSONField(default=list, null=True)
+    index_heading_entries = models.BooleanField(default=False)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class LocalPdfConfig(BaseModel):
+    input_files = models.JSONField(default=list, null=True)
+    input_filter = models.JSONField(default=list, null=True)
+    index_heading_entries = models.BooleanField(default=False)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class LocalPlaintextConfig(BaseModel):
+    input_files = models.JSONField(default=list, null=True)
+    input_filter = models.JSONField(default=list, null=True)
+    index_heading_entries = models.BooleanField(default=False)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class ConversationProcessorConfig(BaseModel):
     conversation = models.JSONField()
     enable_offline_chat = models.BooleanField(default=False)
+
+
+class Embeddings(BaseModel):
+    class EmbeddingsType(models.TextChoices):
+        IMAGE = "image"
+        PDF = "pdf"
+        PLAINTEXT = "plaintext"
+        MARKDOWN = "markdown"
+        ORG = "org"
+        NOTION = "notion"
+        GITHUB = "github"
+        CONVERSATION = "conversation"
+
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
+    embeddings = VectorField(dimensions=384)
+    raw = models.TextField()
+    compiled = models.TextField()
+    heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
+    file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
+    file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
+    file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
+    url = models.URLField(max_length=400, default=None, null=True, blank=True)
+    hashed_value = models.CharField(max_length=100)
+    corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
+
+
+class EmbeddingsDates(BaseModel):
+    date = models.DateField()
+    embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
+
+    class Meta:
+        indexes = [
+            models.Index(fields=["date"]),
+        ]

src/database/tests.py (new file, 3 lines)

@@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

src/khoj/configure.py

@@ -30,6 +30,8 @@ from khoj.utils.helpers import resolve_absolute_path, merge_dicts
 from khoj.utils.fs_syncer import collect_files
 from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
 from khoj.routers.indexer import configure_content, load_content, configure_search
+from database.models import KhojUser
+from database.adapters import get_all_users
 
 logger = logging.getLogger(__name__)
@@ -48,14 +50,28 @@
         from database.models import KhojUser
 
         self.khojuser_manager = KhojUser.objects
+        self._initialize_default_user()
         super().__init__()
 
+    def _initialize_default_user(self):
+        if not self.khojuser_manager.filter(username="default").exists():
+            self.khojuser_manager.create_user(
+                username="default",
+                email="default@example.com",
+                password="default",
+            )
+
     async def authenticate(self, request):
         current_user = request.session.get("user")
         if current_user and current_user.get("email"):
             user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
             if user:
                 return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
+        elif not state.anonymous_mode:
+            user = await self.khojuser_manager.filter(username="default").afirst()
+            if user:
+                return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
         return AuthCredentials(), UnauthenticatedUser()
@@ -78,7 +94,11 @@ def initialize_server(config: Optional[FullConfig]):
 def configure_server(
-    config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
+    config: FullConfig,
+    regenerate: bool = False,
+    search_type: Optional[SearchType] = None,
+    init=False,
+    user: KhojUser = None,
 ):
     # Update Config
     state.config = config
@@ -95,7 +115,7 @@ def configure_server(
         state.config_lock.acquire()
         state.SearchType = configure_search_types(state.config)
         state.search_models = configure_search(state.search_models, state.config.search_type)
-        initialize_content(regenerate, search_type, init)
+        initialize_content(regenerate, search_type, init, user)
     except Exception as e:
         logger.error(f"🚨 Failed to configure search models", exc_info=True)
         raise e
@@ -103,7 +123,7 @@ def configure_server(
         state.config_lock.release()
 
 
-def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
+def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
     # Initialize Content from Config
     if state.search_models:
         try:
@@ -112,7 +132,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
                 state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
             else:
                 logger.info("📬 Updating content index...")
-                all_files = collect_files(state.config.content_type)
+                all_files = collect_files(user=user)
                 state.content_index = configure_content(
                     state.content_index,
                     state.config.content_type,
@@ -120,6 +140,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
                     state.search_models,
                     regenerate,
                     search_type,
+                    user=user,
                 )
         except Exception as e:
             logger.error(f"🚨 Failed to index content", exc_info=True)
@@ -152,9 +173,14 @@ if not state.demo:
     def update_search_index():
         try:
             logger.info("📬 Updating content index via Scheduler")
-            all_files = collect_files(state.config.content_type)
+            for user in get_all_users():
+                all_files = collect_files(user=user)
+                state.content_index = configure_content(
+                    state.content_index, state.config.content_type, all_files, state.search_models, user=user
+                )
+            all_files = collect_files(user=None)
             state.content_index = configure_content(
-                state.content_index, state.config.content_type, all_files, state.search_models
+                state.content_index, state.config.content_type, all_files, state.search_models, user=None
             )
             logger.info("📪 Content index updated via Scheduler")
         except Exception as e:
@@ -164,13 +190,9 @@ if not state.demo:
 
 def configure_search_types(config: FullConfig):
     # Extract core search types
     core_search_types = {e.name: e.value for e in SearchType}
-    # Extract configured plugin search types
-    plugin_search_types = {}
-    if config.content_type and config.content_type.plugins:
-        plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
 
     # Dynamically generate search type enum by merging core search types with configured plugin search types
-    return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
+    return Enum("SearchType", merge_dicts(core_search_types, {}))
 
 
 def configure_processor(


@@ -10,12 +10,10 @@
     <img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
     <h3 class="card-title">
         Github
-        {% if current_config.content_type.github %}
-            {% if current_model_state.github == False %}
-                <img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
-            {% else %}
-                <img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
-            {% endif %}
+        {% if current_model_state.github == False %}
+            <img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
+        {% else %}
+            <img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
         {% endif %}
     </h3>
 </div>
@@ -24,7 +22,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/github">
-        {% if current_config.content_type.github %}
+        {% if current_model_state.github %}
             Update
         {% else %}
             Setup
@@ -32,7 +30,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.github %}
+{% if current_model_state.github %}
 <div id="clear-github" class="card-action-row">
     <button class="card-button" onclick="clearContentType('github')">
         Disable
@@ -45,12 +43,10 @@
     <img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
     <h3 class="card-title">
         Notion
-        {% if current_config.content_type.notion %}
-            {% if current_model_state.notion == False %}
-                <img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
-            {% else %}
-                <img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
-            {% endif %}
+        {% if current_model_state.notion == False %}
+            <img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
+        {% else %}
+            <img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
         {% endif %}
     </h3>
 </div>
@@ -59,7 +55,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/notion">
-        {% if current_config.content_type.content %}
+        {% if current_model_state.content %}
             Update
         {% else %}
             Setup
@@ -67,7 +63,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.notion %}
+{% if current_model_state.notion %}
 <div id="clear-notion" class="card-action-row">
     <button class="card-button" onclick="clearContentType('notion')">
         Disable
@@ -80,7 +76,7 @@
     <img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
     <h3 class="card-title">
         Markdown
-        {% if current_config.content_type.markdown %}
+        {% if current_model_state.markdown %}
            {% if current_model_state.markdown == False%}
                <img id="misconfigured-icon-markdown" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
            {% else %}
@@ -94,7 +90,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/markdown">
-        {% if current_config.content_type.markdown %}
+        {% if current_model_state.markdown %}
             Update
         {% else %}
             Setup
@@ -102,7 +98,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.markdown %}
+{% if current_model_state.markdown %}
 <div id="clear-markdown" class="card-action-row">
     <button class="card-button" onclick="clearContentType('markdown')">
         Disable
@@ -115,7 +111,7 @@
     <img class="card-icon" src="/static/assets/icons/org.svg" alt="org">
     <h3 class="card-title">
         Org
-        {% if current_config.content_type.org %}
+        {% if current_model_state.org %}
            {% if current_model_state.org == False %}
                <img id="misconfigured-icon-org" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
            {% else %}
@@ -129,7 +125,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/org">
-        {% if current_config.content_type.org %}
+        {% if current_model_state.org %}
             Update
         {% else %}
             Setup
@@ -137,7 +133,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.org %}
+{% if current_model_state.org %}
 <div id="clear-org" class="card-action-row">
     <button class="card-button" onclick="clearContentType('org')">
         Disable
@@ -150,7 +146,7 @@
     <img class="card-icon" src="/static/assets/icons/pdf.svg" alt="PDF">
     <h3 class="card-title">
         PDF
-        {% if current_config.content_type.pdf %}
+        {% if current_model_state.pdf %}
            {% if current_model_state.pdf == False %}
                <img id="misconfigured-icon-pdf" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
            {% else %}
@@ -164,7 +160,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/pdf">
-        {% if current_config.content_type.pdf %}
+        {% if current_model_state.pdf %}
             Update
         {% else %}
             Setup
@@ -172,7 +168,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.pdf %}
+{% if current_model_state.pdf %}
 <div id="clear-pdf" class="card-action-row">
     <button class="card-button" onclick="clearContentType('pdf')">
         Disable
@@ -185,7 +181,7 @@
     <img class="card-icon" src="/static/assets/icons/plaintext.svg" alt="Plaintext">
     <h3 class="card-title">
         Plaintext
-        {% if current_config.content_type.plaintext %}
+        {% if current_model_state.plaintext %}
            {% if current_model_state.plaintext == False %}
                <img id="misconfigured-icon-plaintext" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
            {% else %}
@@ -199,7 +195,7 @@
 </div>
 <div class="card-action-row">
     <a class="card-button" href="/config/content_type/plaintext">
-        {% if current_config.content_type.plaintext %}
+        {% if current_model_state.plaintext %}
             Update
         {% else %}
             Setup
@@ -207,7 +203,7 @@
         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
     </a>
 </div>
-{% if current_config.content_type.plaintext %}
+{% if current_model_state.plaintext %}
 <div id="clear-plaintext" class="card-action-row">
     <button class="card-button" onclick="clearContentType('plaintext')">
         Disable


@@ -38,24 +38,6 @@
         {% endfor %}
     </div>
     <button type="button" id="add-repository-button">Add Repository</button>
-    <table style="display: none;" >
-        <tr>
-            <td>
-                <label for="compressed-jsonl">Compressed JSONL (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
-            </td>
-        </tr>
-        <tr>
-            <td>
-                <label for="embeddings-file">Embeddings File (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
-            </td>
-        </tr>
-    </table>
     <div class="section">
         <div id="success" style="display: none;"></div>
         <button id="submit" type="submit">Save</button>
@@ -107,8 +89,6 @@
         submit.addEventListener("click", function(event) {
             event.preventDefault();
 
-            const compressed_jsonl = document.getElementById("compressed-jsonl").value;
-            const embeddings_file = document.getElementById("embeddings-file").value;
             const pat_token = document.getElementById("pat-token").value;
 
             if (pat_token == "") {
@@ -154,8 +134,6 @@
                 body: JSON.stringify({
                     "pat_token": pat_token,
                     "repos": repos,
-                    "compressed_jsonl": compressed_jsonl,
-                    "embeddings_file": embeddings_file,
                 })
             })
             .then(response => response.json())


@@ -43,33 +43,6 @@
             </td>
         </tr>
     </table>
-    <table style="display: none;" >
-        <tr>
-            <td>
-                <label for="compressed-jsonl">Compressed JSONL (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
-            </td>
-        </tr>
-        <tr>
-            <td>
-                <label for="embeddings-file">Embeddings File (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
-            </td>
-        </tr>
-        <tr>
-            <td>
-                <label for="index-heading-entries">Index Heading Entries</label>
-            </td>
-            <td>
-                <input type="text" id="index-heading-entries" name="index-heading-entries" value="{{ current_config['index_heading_entries'] }}">
-            </td>
-        </tr>
-    </table>
     <div class="section">
         <div id="success" style="display: none;" ></div>
         <button id="submit" type="submit">Save</button>
@@ -155,9 +128,8 @@
                 inputFilter = null;
             }
 
-            var compressed_jsonl = document.getElementById("compressed-jsonl").value;
-            var embeddings_file = document.getElementById("embeddings-file").value;
-            var index_heading_entries = document.getElementById("index-heading-entries").value;
+            // var index_heading_entries = document.getElementById("index-heading-entries").value;
+            var index_heading_entries = true;
 
             const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
             fetch('/api/config/data/content_type/{{ content_type }}', {
@@ -169,8 +141,6 @@
                 body: JSON.stringify({
                     "input_files": inputFiles,
                     "input_filter": inputFilter,
-                    "compressed_jsonl": compressed_jsonl,
-                    "embeddings_file": embeddings_file,
                     "index_heading_entries": index_heading_entries
                 })
             })


@@ -20,24 +20,6 @@
             </td>
         </tr>
     </table>
-    <table style="display: none;" >
-        <tr>
-            <td>
-                <label for="compressed-jsonl">Compressed JSONL (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
-            </td>
-        </tr>
-        <tr>
-            <td>
-                <label for="embeddings-file">Embeddings File (Output)</label>
-            </td>
-            <td>
-                <input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
-            </td>
-        </tr>
-    </table>
     <div class="section">
         <div id="success" style="display: none;"></div>
         <button id="submit" type="submit">Save</button>
@@ -51,8 +33,6 @@
         submit.addEventListener("click", function(event) {
            event.preventDefault();
 
-            const compressed_jsonl = document.getElementById("compressed-jsonl").value;
-            const embeddings_file = document.getElementById("embeddings-file").value;
            const token = document.getElementById("token").value;
 
            if (token == "") {
@@ -70,8 +50,6 @@
                },
                body: JSON.stringify({
                    "token": token,
-                    "compressed_jsonl": compressed_jsonl,
-                    "embeddings_file": embeddings_file,
                })
            })
            .then(response => response.json())
View file
@ -172,7 +172,7 @@
url = createRequestUrl(query, type, results_count || 5, rerank); url = createRequestUrl(query, type, results_count || 5, rerank);
fetch(url, { fetch(url, {
headers: { headers: {
"X-CSRFToken": csrfToken "Content-Type": "application/json"
} }
}) })
.then(response => response.json()) .then(response => response.json())
@ -199,8 +199,8 @@
fetch("/api/config/types") fetch("/api/config/types")
.then(response => response.json()) .then(response => response.json())
.then(enabled_types => { .then(enabled_types => {
// Show warning if no content types are enabled // Show warning if no content types are enabled, or just one ("all")
if (enabled_types.detail) { if (enabled_types[0] === "all" && enabled_types.length === 1) {
document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>"; document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>";
document.getElementById("query").setAttribute("disabled", "disabled"); document.getElementById("query").setAttribute("disabled", "disabled");
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search"); document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");
View file
@ -24,10 +24,15 @@ from rich.logging import RichHandler
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from django.core.management import call_command from django.core.management import call_command
# Internal Packages # Initialize Django
from khoj.configure import configure_routes, initialize_server, configure_middleware os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
from khoj.utils import state django.setup()
from khoj.utils.cli import cli
# Initialize Django Database
call_command("migrate", "--noinput")
# Initialize Django Static Files
call_command("collectstatic", "--noinput")
# Initialize Django # Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
@ -54,6 +59,11 @@ app.add_middleware(
# Set Locale # Set Locale
locale.setlocale(locale.LC_ALL, "") locale.setlocale(locale.LC_ALL, "")
# Internal Packages. We do this after setting up Django so that Django features are accessible to the app.
from khoj.configure import configure_routes, initialize_server, configure_middleware
from khoj.utils import state
from khoj.utils.cli import cli
# Setup Logger # Setup Logger
rich_handler = RichHandler(rich_tracebacks=True) rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]")) rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
@ -95,6 +105,8 @@ def run():
# Mount Django and Static Files # Mount Django and Static Files
app.mount("/django", django_app, name="django") app.mount("/django", django_app, name="django")
if not os.path.exists("static"):
os.mkdir("static")
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
# Configure Middleware # Configure Middleware
@ -111,6 +123,7 @@ def set_state(args):
state.host = args.host state.host = args.host
state.port = args.port state.port = args.port
state.demo = args.demo state.demo = args.demo
state.anonymous_mode = args.anonymous_mode
state.khoj_version = version("khoj-assistant") state.khoj_version = version("khoj-assistant")
View file
@ -0,0 +1,57 @@
from typing import List
import torch
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
encode_kwargs = {"normalize_embeddings": True}
# encode_kwargs = {}
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
model_kwargs = {"device": device}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
def embed_query(self, query):
return self.embeddings_model.embed_query(query)
def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
def predict(self, query, hits: List[SearchResponse]):
cross_inp = [[query, hit.additional["compiled"]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross_inp)
return cross_scores
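A minimal usage sketch of the two classes above (not part of the commit; the inputs are illustrative, and the reranking call assumes hits are SearchResponse objects carrying a compiled entry in additional, as the predict() signature above expects):

    from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel

    embedder = EmbeddingsModel()
    # One normalized vector per document (normalize_embeddings=True above)
    doc_vectors = embedder.embed_documents(["First note", "Second note"])
    query_vector = embedder.embed_query("what did I write about notes?")

    # Rerank search hits against the query with the cross-encoder, e.g.:
    # cross_scores = CrossEncoderModel().predict("what did I write about notes?", hits)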
View file
@ -2,7 +2,7 @@
import logging import logging
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, List, Union from typing import Dict, List, Union, Tuple
# External Packages # External Packages
import requests import requests
@ -12,18 +12,31 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, GithubConfig, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GithubToJsonl(TextToJsonl): class GithubToJsonl(TextEmbeddings):
def __init__(self, config: GithubContentConfig): def __init__(self, config: GithubConfig):
super().__init__(config) super().__init__(config)
self.config = config raw_repos = config.githubrepoconfig.all()
repos = []
for repo in raw_repos:
repos.append(
GithubRepoConfig(
name=repo.name,
owner=repo.owner,
branch=repo.branch,
)
)
self.config = GithubContentConfig(
pat_token=config.pat_token,
repos=repos,
)
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"}) self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
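A sketch of how this constructor is now driven from the database rather than a YAML file (illustrative; the per-user filter on GithubConfig is an assumption, while the pat_token, name, owner and branch fields are the ones read above):

    # GithubConfig carries the pat_token; its reverse relation githubrepoconfig
    # holds one row per tracked repository with name, owner and branch.
    config = GithubConfig.objects.filter(user=user).first()
    if config:
        processor = GithubToJsonl(config)
        num_new, num_deleted = processor.process(user=user)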
@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl):
else: else:
return return
def process(self, previous_entries=[], files=None, full_corpus=True): def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "": if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content") logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content") raise ValueError("Github PAT token is not set. Skipping github content")
@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl):
for repo in self.config.repos: for repo in self.config.repos:
current_entries += self.process_repo(repo) current_entries += self.process_repo(repo)
return self.update_entries_with_ids(current_entries, previous_entries) return self.update_entries_with_ids(current_entries, user=user)
def process_repo(self, repo: GithubRepoConfig): def process_repo(self, repo: GithubRepoConfig):
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}" repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl):
current_entries += issue_entries current_entries += issue_entries
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger): with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256) current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256)
return current_entries return current_entries
def update_entries_with_ids(self, current_entries, previous_entries): def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
) )
with timer("Write github entries to JSONL file", logger): return num_new_embeddings, num_deleted_embeddings
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids
def get_files(self, repo_url: str, repo: GithubRepoConfig): def get_files(self, repo_url: str, repo: GithubRepoConfig):
# Get the contents of the repository # Get the contents of the repository
View file
@ -1,91 +0,0 @@
# Standard Packages
import glob
import logging
from pathlib import Path
from typing import List
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class JsonlToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[], files: dict[str, str] = {}, full_corpus: bool = True):
# Extract required fields from config
input_jsonl_files, input_jsonl_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Get Jsonl Input Files to Process
all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter)
# Extract Entries from specified jsonl files
with timer("Parse entries from jsonl files", logger):
input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files)
current_entries = list(map(Entry.from_dict, input_jsons))
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
@staticmethod
def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None):
"Get all jsonl files to process"
absolute_jsonl_files, filtered_jsonl_files = set(), set()
if jsonl_files:
absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files}
if jsonl_file_filters:
filtered_jsonl_files = {
filtered_file
for jsonl_file_filter in jsonl_file_filters
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
}
all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files)
files_with_non_jsonl_extensions = {
jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl")
}
if any(files_with_non_jsonl_extensions):
print(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}")
logger.debug(f"Processing files: {all_jsonl_files}")
return all_jsonl_files
@staticmethod
def extract_jsonl_entries(jsonl_files):
"Extract entries from specified jsonl files"
entries = []
for jsonl_file in jsonl_files:
entries.extend(load_jsonl(Path(jsonl_file)))
return entries
@staticmethod
def convert_entries_to_jsonl(entries: List[Entry]):
"Convert each entry to JSON and collate as JSONL"
return "".join([f"{entry.to_json()}\n" for entry in entries])
View file
@ -3,29 +3,28 @@ import logging
import re import re
import urllib3 import urllib3
from pathlib import Path from pathlib import Path
from typing import List from typing import Tuple, List
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences from khoj.utils.constants import empty_escape_sequences
from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry
from khoj.utils.rawconfig import Entry, TextContentConfig from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MarkdownToJsonl(TextToJsonl): class MarkdownToJsonl(TextEmbeddings):
def __init__(self, config: TextContentConfig): def __init__(self):
super().__init__(config) super().__init__()
self.config = config
# Define Functions # Define Functions
def process(self, previous_entries=[], files=None, full_corpus: bool = True): def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
output_file = self.config.compressed_jsonl
if not full_corpus: if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names current_entries,
Embeddings.EmbeddingsType.MARKDOWN,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
) )
with timer("Write markdown entries to JSONL file", logger): return num_new_embeddings, num_deleted_embeddings
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
@staticmethod @staticmethod
def extract_markdown_entries(markdown_files): def extract_markdown_entries(markdown_files):
View file
@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import logging import logging
from typing import Tuple
# External Packages # External Packages
import requests import requests
@ -7,9 +8,9 @@ import requests
# Internal Packages # Internal Packages
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser, NotionConfig
from enum import Enum from enum import Enum
@ -49,10 +50,12 @@ class NotionBlockType(Enum):
CALLOUT = "callout" CALLOUT = "callout"
class NotionToJsonl(TextToJsonl): class NotionToJsonl(TextEmbeddings):
def __init__(self, config: NotionContentConfig): def __init__(self, config: NotionConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = NotionContentConfig(
token=config.token,
)
self.session = requests.Session() self.session = requests.Session()
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"}) self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
self.unsupported_block_types = [ self.unsupported_block_types = [
@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl):
self.body_params = {"page_size": 100} self.body_params = {"page_size": 100}
def process(self, previous_entries=[], files=None, full_corpus=True): def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
current_entries = [] current_entries = []
# Get all pages # Get all pages
@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl):
page_entries = self.process_page(p_or_d) page_entries = self.process_page(p_or_d)
current_entries.extend(page_entries) current_entries.extend(page_entries)
return self.update_entries_with_ids(current_entries, previous_entries) return self.update_entries_with_ids(current_entries, user)
def process_page(self, page): def process_page(self, page):
page_id = page["id"] page_id = page["id"]
@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl):
title = None title = None
return title, content return title, content
def update_entries_with_ids(self, current_entries, previous_entries): def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
) )
with timer("Write Notion entries to JSONL file", logger): return num_new_embeddings, num_deleted_embeddings
# Process Each Entry from all Notion entries
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids
View file
@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple
# Internal Packages # Internal Packages
from khoj.processor.org_mode import orgnode from khoj.processor.org_mode import orgnode
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils import state from khoj.utils import state
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl): class OrgToJsonl(TextEmbeddings):
def __init__(self, config: TextContentConfig): def __init__(self):
super().__init__(config) super().__init__()
self.config = config
# Define Functions # Define Functions
def process( def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> List[Tuple[int, Entry]]: ) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
output_file = self.config.compressed_jsonl index_heading_entries = True
index_heading_entries = self.config.index_heading_entries
if not full_corpus: if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names current_entries,
Embeddings.EmbeddingsType.ORG,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
) )
# Process Each Entry from All Notes Files return num_new_embeddings, num_deleted_embeddings
with timer("Write org entries to JSONL file", logger):
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
@staticmethod @staticmethod
def extract_org_entries(org_files: dict[str, str]): def extract_org_entries(org_files: dict[str, str]):
View file
@ -1,28 +1,31 @@
# Standard Packages # Standard Packages
import os import os
import logging import logging
from typing import List from typing import List, Tuple
import base64 import base64
# External Packages # External Packages
from langchain.document_loaders import PyMuPDFLoader from langchain.document_loaders import PyMuPDFLoader
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PdfToJsonl(TextToJsonl): class PdfToJsonl(TextEmbeddings):
# Define Functions def __init__(self):
def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True): super().__init__()
# Extract required fields from config
output_file = self.config.compressed_jsonl
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus: if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names current_entries,
Embeddings.EmbeddingsType.PDF,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
) )
with timer("Write PDF entries to JSONL file", logger): return num_new_embeddings, num_deleted_embeddings
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
@staticmethod @staticmethod
def extract_pdf_entries(pdf_files): def extract_pdf_entries(pdf_files):
@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl):
entry_to_location_map = [] entry_to_location_map = []
for pdf_file in pdf_files: for pdf_file in pdf_files:
try: try:
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
tmp_file = f"tmp_pdf_file.pdf" tmp_file = f"tmp_pdf_file.pdf"
with open(f"{tmp_file}", "wb") as f: with open(f"{tmp_file}", "wb") as f:
bytes = pdf_files[pdf_file] bytes = pdf_files[pdf_file]
View file
@ -4,22 +4,23 @@ from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils.rawconfig import Entry from database.models import Embeddings, KhojUser, LocalPlaintextConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PlaintextToJsonl(TextToJsonl): class PlaintextToJsonl(TextEmbeddings):
def __init__(self):
super().__init__()
# Define Functions # Define Functions
def process( def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> List[Tuple[int, Entry]]: ) -> Tuple[int, int]:
output_file = self.config.compressed_jsonl
if not full_corpus: if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names files_to_process = set(files) - deletion_file_names
@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger): with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update( num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names current_entries,
Embeddings.EmbeddingsType.PLAINTEXT,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
) )
with timer("Write entries to JSONL file", logger): return num_new_embeddings, num_deleted_embeddings
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(plaintext_data, output_file)
return entries_with_ids
@staticmethod @staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]: def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
View file
@ -2,24 +2,33 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import hashlib import hashlib
import logging import logging
from typing import Callable, List, Tuple, Set import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
# Internal Packages # Internal Packages
from khoj.utils.rawconfig import Entry, TextConfigBase from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Embeddings, EmbeddingsDates
from database.adapters import EmbeddingsAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToJsonl(ABC): class TextEmbeddings(ABC):
def __init__(self, config: TextConfigBase): def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel()
self.config = config self.config = config
self.date_filter = DateFilter()
@abstractmethod @abstractmethod
def process( def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> List[Tuple[int, Entry]]: ) -> Tuple[int, int]:
... ...
@staticmethod @staticmethod
@ -38,6 +47,7 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models # Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
corpus_id = uuid.uuid4()
# Split entry into chunks of max tokens # Split entry into chunks of max tokens
for chunk_index in range(0, len(compiled_entry_words), max_tokens): for chunk_index in range(0, len(compiled_entry_words), max_tokens):
@ -57,11 +67,103 @@ class TextToJsonl(ABC):
raw=entry.raw, raw=entry.raw,
heading=entry.heading, heading=entry.heading,
file=entry.file, file=entry.file,
corpus_id=corpus_id,
) )
) )
return chunked_entries return chunked_entries
def update_embeddings(
self,
current_entries: List[Entry],
file_type: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry))
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.info(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = Embeddings.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
# for hashed_val in hashes_for_file:
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
# hashes_to_process.add(hashed_val)
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
embeddings_to_create = []
for hashed_val, embedding in zip(hashes_to_process, embeddings):
entry = hash_to_current_entries[hashed_val]
embeddings_to_create.append(
Embeddings(
user=user,
embeddings=embedding,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_path=entry.file,
file_type=file_type,
hashed_value=hashed_val,
corpus_id=entry.corpus_id,
)
)
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
num_new_embeddings += len(new_embeddings)
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for embedding in new_embeddings:
dates = self.date_filter.extract_dates(embedding.raw)
for date in dates:
dates_to_create.append(
EmbeddingsDates(
date=date,
embeddings=embedding,
)
)
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.info(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file:
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
num_deleted_embeddings += deleted_count
return num_new_embeddings, num_deleted_embeddings
@staticmethod @staticmethod
def mark_entries_for_update( def mark_entries_for_update(
current_entries: List[Entry], current_entries: List[Entry],
@ -72,11 +174,11 @@ class TextToJsonl(ABC):
): ):
# Hash all current and previous entries to identify new entries # Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger): with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries)) previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries))
if deletion_filenames is not None: if deletion_filenames is not None:
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames] deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries)) deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries))
else: else:
deletion_entry_hashes = [] deletion_entry_hashes = []
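A sketch of how a content processor is expected to drive the update_embeddings flow above (illustrative: the subclass name and the extract_entries helper are hypothetical; TextEmbeddings, Embeddings and the update_embeddings keyword arguments are the ones introduced in this commit):

    import logging
    from typing import Tuple

    from khoj.processor.text_to_jsonl import TextEmbeddings
    from database.models import Embeddings, KhojUser

    logger = logging.getLogger(__name__)

    class NotesToEmbeddings(TextEmbeddings):  # hypothetical processor
        def process(
            self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
        ) -> Tuple[int, int]:
            current_entries = self.extract_entries(files)  # hypothetical helper: parse files into Entry objects
            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
            return self.update_embeddings(
                current_entries,
                Embeddings.EmbeddingsType.PLAINTEXT,
                key="compiled",
                logger=logger,
                user=user,
                regenerate=regenerate,
            )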
View file
@ -2,14 +2,15 @@
import concurrent.futures import concurrent.futures
import math import math
import time import time
import yaml
import logging import logging
import json import json
from typing import List, Optional, Union, Any from typing import List, Optional, Union, Any
import asyncio
# External Packages # External Packages
from fastapi import APIRouter, HTTPException, Header, Request from fastapi import APIRouter, HTTPException, Header, Request, Depends
from sentence_transformers import util from starlette.authentication import requires
from asgiref.sync import sync_to_async
# Internal Packages # Internal Packages
from khoj.configure import configure_processor, configure_server from khoj.configure import configure_processor, configure_server
@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ContentConfig,
FullConfig, FullConfig,
ProcessorConfig, ProcessorConfig,
SearchConfig, SearchConfig,
@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from fastapi.requests import Request from fastapi.requests import Request
from database import adapters
from database.adapters import EmbeddingsAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
# Initialize Router # Initialize Router
api = APIRouter() api = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def map_config_to_object(content_type: str):
if content_type == "org":
return LocalOrgConfig
if content_type == "markdown":
return LocalMarkdownConfig
if content_type == "pdf":
return LocalPdfConfig
if content_type == "plaintext":
return LocalPlaintextConfig
async def map_config_to_db(config: FullConfig, user: KhojUser):
if config.content_type:
if config.content_type.org:
await LocalOrgConfig.objects.filter(user=user).adelete()
await LocalOrgConfig.objects.acreate(
input_files=config.content_type.org.input_files,
input_filter=config.content_type.org.input_filter,
index_heading_entries=config.content_type.org.index_heading_entries,
user=user,
)
if config.content_type.markdown:
await LocalMarkdownConfig.objects.filter(user=user).adelete()
await LocalMarkdownConfig.objects.acreate(
input_files=config.content_type.markdown.input_files,
input_filter=config.content_type.markdown.input_filter,
index_heading_entries=config.content_type.markdown.index_heading_entries,
user=user,
)
if config.content_type.pdf:
await LocalPdfConfig.objects.filter(user=user).adelete()
await LocalPdfConfig.objects.acreate(
input_files=config.content_type.pdf.input_files,
input_filter=config.content_type.pdf.input_filter,
index_heading_entries=config.content_type.pdf.index_heading_entries,
user=user,
)
if config.content_type.plaintext:
await LocalPlaintextConfig.objects.filter(user=user).adelete()
await LocalPlaintextConfig.objects.acreate(
input_files=config.content_type.plaintext.input_files,
input_filter=config.content_type.plaintext.input_filter,
index_heading_entries=config.content_type.plaintext.index_heading_entries,
user=user,
)
if config.content_type.github:
await adapters.set_user_github_config(
user=user,
pat_token=config.content_type.github.pat_token,
repos=config.content_type.github.repos,
)
if config.content_type.notion:
await adapters.set_notion_config(
user=user,
token=config.content_type.notion.token,
)
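How these two helpers compose (a sketch, not in the diff; the content type and the config values are illustrative): map_config_to_db persists a client-posted FullConfig as per-user rows, while map_config_to_object resolves the Django model class for a single local content type so the generic endpoints below can delete or recreate just that slice.

    # inside an async endpoint, with user an authenticated KhojUser
    content_object = map_config_to_object("org")                # -> LocalOrgConfig
    await content_object.objects.filter(user=user).adelete()    # drop this user's org settings
    await map_config_to_db(updated_config, user)                # recreate rows from the posted FullConfig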
# If it's a demo instance, prevent updating any of the configuration. # If it's a demo instance, prevent updating any of the configuration.
if not state.demo: if not state.demo:
@ -64,7 +127,10 @@ if not state.demo:
state.processor_config = configure_processor(state.config.processor) state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig) @api.get("/config/data", response_model=FullConfig)
def get_config_data(): def get_config_data(request: Request):
user = request.user.object if request.user.is_authenticated else None
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
return state.config return state.config
@api.post("/config/data") @api.post("/config/data")
@ -73,20 +139,19 @@ if not state.demo:
updated_config: FullConfig, updated_config: FullConfig,
client: Optional[str] = None, client: Optional[str] = None,
): ):
state.config = updated_config user = request.user.object if request.user.is_authenticated else None
with open(state.config_file, "w") as outfile: await map_config_to_db(updated_config, user)
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
configuration_update_metadata = dict() configuration_update_metadata = {}
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
if state.config.content_type is not None: if state.config.content_type is not None:
configuration_update_metadata["github"] = state.config.content_type.github is not None configuration_update_metadata["github"] = "github" in enabled_content
configuration_update_metadata["notion"] = state.config.content_type.notion is not None configuration_update_metadata["notion"] = "notion" in enabled_content
configuration_update_metadata["org"] = state.config.content_type.org is not None configuration_update_metadata["org"] = "org" in enabled_content
configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None configuration_update_metadata["pdf"] = "pdf" in enabled_content
configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None configuration_update_metadata["markdown"] = "markdown" in enabled_content
configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None
if state.config.processor is not None: if state.config.processor is not None:
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
@ -101,6 +166,7 @@ if not state.demo:
return state.config return state.config
@api.post("/config/data/content_type/github", status_code=200) @api.post("/config/data/content_type/github", status_code=200)
@requires("authenticated")
async def set_content_config_github_data( async def set_content_config_github_data(
request: Request, request: Request,
updated_config: Union[GithubContentConfig, None], updated_config: Union[GithubContentConfig, None],
@ -108,10 +174,13 @@ if not state.demo:
): ):
_initialize_config() _initialize_config()
if not state.config.content_type: user = request.user.object if request.user.is_authenticated else None
state.config.content_type = ContentConfig(**{"github": updated_config})
else: await adapters.set_user_github_config(
state.config.content_type.github = updated_config user=user,
pat_token=updated_config.pat_token,
repos=updated_config.repos,
)
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -121,11 +190,7 @@ if not state.demo:
metadata={"content_type": "github"}, metadata={"content_type": "github"},
) )
try: return {"status": "ok"}
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/notion", status_code=200) @api.post("/config/data/content_type/notion", status_code=200)
async def set_content_config_notion_data( async def set_content_config_notion_data(
@ -135,10 +200,12 @@ if not state.demo:
): ):
_initialize_config() _initialize_config()
if not state.config.content_type: user = request.user.object if request.user.is_authenticated else None
state.config.content_type = ContentConfig(**{"notion": updated_config})
else: await adapters.set_notion_config(
state.config.content_type.notion = updated_config user=user,
token=updated_config.token,
)
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -148,11 +215,7 @@ if not state.demo:
metadata={"content_type": "notion"}, metadata={"content_type": "notion"},
) )
try: return {"status": "ok"}
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200) @api.post("/delete/config/data/content_type/{content_type}", status_code=200)
async def remove_content_config_data( async def remove_content_config_data(
@ -160,8 +223,7 @@ if not state.demo:
content_type: str, content_type: str,
client: Optional[str] = None, client: Optional[str] = None,
): ):
if not state.config or not state.config.content_type: user = request.user.object if request.user.is_authenticated else None
return {"status": "ok"}
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -171,31 +233,13 @@ if not state.demo:
metadata={"content_type": content_type}, metadata={"content_type": content_type},
) )
if state.config.content_type: content_object = map_config_to_object(content_type)
state.config.content_type[content_type] = None await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
if content_type == "github": enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
state.content_index.github = None
elif content_type == "notion":
state.content_index.notion = None
elif content_type == "plugins":
state.content_index.plugins = None
elif content_type == "pdf":
state.content_index.pdf = None
elif content_type == "markdown":
state.content_index.markdown = None
elif content_type == "org":
state.content_index.org = None
elif content_type == "plaintext":
state.content_index.plaintext = None
else:
logger.warning(f"Request to delete unknown content type: {content_type} via API")
try: return {"status": "ok"}
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200) @api.post("/delete/config/data/processor/conversation/openai", status_code=200)
async def remove_processor_conversation_config_data( async def remove_processor_conversation_config_data(
@ -228,6 +272,7 @@ if not state.demo:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/{content_type}", status_code=200) @api.post("/config/data/content_type/{content_type}", status_code=200)
# @requires("authenticated")
async def set_content_config_data( async def set_content_config_data(
request: Request, request: Request,
content_type: str, content_type: str,
@ -236,10 +281,10 @@ if not state.demo:
): ):
_initialize_config() _initialize_config()
if not state.config.content_type: user = request.user.object if request.user.is_authenticated else None
state.config.content_type = ContentConfig(**{content_type: updated_config})
else: content_object = map_config_to_object(content_type)
state.config.content_type[content_type] = updated_config await adapters.set_text_content_config(user, content_object, updated_config)
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -249,11 +294,7 @@ if not state.demo:
metadata={"content_type": content_type}, metadata={"content_type": content_type},
) )
try: return {"status": "ok"}
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
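For reference, this endpoint receives the body built by the settings page JavaScript earlier in this diff. An illustrative call (the port and file paths are examples; authentication middleware may additionally apply):

    import requests

    requests.post(
        "http://localhost:42110/api/config/data/content_type/markdown",
        json={
            "input_files": None,
            "input_filter": ["~/notes/**/*.md"],
            "index_heading_entries": True,
        },
    )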
@api.post("/config/data/processor/conversation/openai", status_code=200) @api.post("/config/data/processor/conversation/openai", status_code=200)
async def set_processor_openai_config_data( async def set_processor_openai_config_data(
@ -337,24 +378,23 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str]) @api.get("/config/types", response_model=List[str])
def get_config_types(): def get_config_types(
"""Get configured content types""" request: Request,
if state.config is None or state.config.content_type is None: ):
raise HTTPException( user = request.user.object if request.user.is_authenticated else None
status_code=500,
detail="Content types not configured. Configure at least one content type on server and restart it.", enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
)
configured_content_types = list(enabled_file_types)
if state.config and state.config.content_type:
for ctype in state.config.content_type.dict(exclude_none=True):
configured_content_types.append(ctype)
configured_content_types = state.config.content_type.dict(exclude_none=True)
return [ return [
search_type.value search_type.value
for search_type in SearchType for search_type in SearchType
if ( if (search_type.value in configured_content_types) or search_type == SearchType.All
search_type.value in configured_content_types
and getattr(state.content_index, search_type.value) is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
] ]
@ -372,6 +412,7 @@ async def search(
referer: Optional[str] = Header(None), referer: Optional[str] = Header(None),
host: Optional[str] = Header(None), host: Optional[str] = Header(None),
): ):
user = request.user.object if request.user.is_authenticated else None
start_time = time.time() start_time = time.time()
# Run validation checks # Run validation checks
@ -390,10 +431,11 @@ async def search(
search_futures: List[concurrent.futures.Future] = [] search_futures: List[concurrent.futures.Future] = []
# return cached results, if available # return cached results, if available
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" if user:
if query_cache_key in state.query_cache: query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
logger.debug(f"Return response from query cache") if query_cache_key in state.query_cache[user.uuid]:
return state.query_cache[query_cache_key] logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
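The lookup above assumes the query cache is now partitioned by user. The container itself is not shown in this diff; a nested mapping of this shape (an assumption, e.g. in khoj.utils.state) would satisfy both the membership test and the later assignment:

    from collections import defaultdict

    # user.uuid -> {query_cache_key: results}
    query_cache = defaultdict(dict)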
# Encode query with filter terms removed # Encode query with filter terms removed
defiltered_query = user_query defiltered_query = user_query
@ -407,84 +449,31 @@ async def search(
] ]
if text_search_models: if text_search_models:
with timer("Encoding query took", logger=logger): with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings( encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
text_search_models[0].bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search: if t in [
# query org-mode notes SearchType.All,
search_futures += [ SearchType.Org,
executor.submit( SearchType.Markdown,
text_search.query, SearchType.Github,
user_query, SearchType.Notion,
state.search_models.text_search, SearchType.Plaintext,
state.content_index.org, ]:
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
# query markdown notes # query markdown notes
search_futures += [ search_futures += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user,
user_query, user_query,
state.search_models.text_search, t,
state.content_index.markdown,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, rank_results=r or False,
score_threshold=score_threshold, score_threshold=score_threshold,
dedupe=dedupe or True,
) )
] ]
if ( elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images # query images
search_futures += [ search_futures += [
executor.submit( executor.submit(
@ -497,70 +486,6 @@ async def search(
) )
] ]
if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [
executor.submit(
text_search.query,
user_query,
plugin_search,
plugin_content,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Plaintext or t == SearchType.All)
and state.content_index.plaintext
and state.search_models.text_search
):
# query plaintext files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.plaintext,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):

@@ -576,15 +501,19 @@ async def search(
count=results_count,
)
else:
- hits, entries = await search_future.result()
+ hits = await search_future.result()
# Collate results
- results += text_search.collate_results(hits, entries, results_count)
- # Sort results across all content types and take top results
- results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]
+ results += text_search.collate_results(hits, dedupe=dedupe)
+ if r:
+ results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
+ else:
+ # Sort results across all content types and take top results
+ results = sorted(results, key=lambda x: float(x.score))[:results_count]

# Cache results
- state.query_cache[query_cache_key] = results
+ if user:
+ state.query_cache[user.uuid][query_cache_key] = results

update_telemetry_state(
request=request,

@@ -596,8 +525,6 @@ async def search(
host=host,
)

- state.previous_query = user_query

end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")

@@ -614,12 +541,13 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
+ user = request.user.object if request.user.is_authenticated else None
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
- configure_server(state.config, regenerate=force, search_type=t)
+ configure_server(state.config, regenerate=force, search_type=t, user=user)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)

@@ -774,6 +702,7 @@ async def extract_references_and_questions(
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
+ user = request.user.object if request.user.is_authenticated else None
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log

@@ -781,7 +710,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = []
inferred_queries: List[str] = []

- if state.content_index is None:
+ if not EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
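The reworked search endpoint collates database hits, optionally reranks them with the cross-encoder, and caches the final results per user. A minimal, self-contained sketch of that post-processing step, with stand-in names (SearchResult, reranker) rather than the project's own API:

from dataclasses import dataclass
from typing import Callable, Dict, List

@dataclass
class SearchResult:
    entry: str
    score: float  # bi-encoder distance; lower means closer

def finalize_results(results: List[SearchResult], rerank: bool,
                     reranker: Callable[[List[SearchResult]], List[SearchResult]],
                     count: int) -> List[SearchResult]:
    # Rerank with a cross-encoder when requested, otherwise sort by bi-encoder score
    if rerank:
        return reranker(results)[:count]
    return sorted(results, key=lambda r: r.score)[:count]

# Per-user query cache, keyed first by user id and then by the query parameters
query_cache: Dict[str, Dict[tuple, List[SearchResult]]] = {}

def cache_results(user_uuid: str, cache_key: tuple, results: List[SearchResult]) -> None:
    query_cache.setdefault(user_uuid, {})[cache_key] = results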

@@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
+ from database.models import KhojUser

logger = logging.getLogger(__name__)

@@ -40,11 +41,13 @@ def update_telemetry_state(
host: Optional[str] = None,
metadata: Optional[dict] = None,
):
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
+ "server_id": str(user.uuid) if user else None,
}

if metadata:
@@ -1,6 +1,7 @@
# Standard Packages
import logging
from typing import Optional, Union, Dict
+ import asyncio

# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile

@@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state
# Internal Packages
from khoj.utils import state, constants
- from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
- from khoj.utils.rawconfig import ContentConfig, TextContentConfig
from khoj.search_type import text_search, image_search
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.config import SearchModels
- from khoj.utils.constants import default_config
from khoj.utils.helpers import LRU, get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
SearchConfig,
)
- from khoj.search_filter.date_filter import DateFilter
- from khoj.search_filter.word_filter import WordFilter
- from khoj.search_filter.file_filter import FileFilter
from khoj.utils.config import (
ContentIndex,
SearchModels,
)
+ from database.models import (
+ KhojUser,
+ GithubConfig,
+ NotionConfig,
+ )

logger = logging.getLogger(__name__)
@@ -68,14 +68,14 @@ async def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
+ user = request.user.object if request.user.is_authenticated else None
if x_api_key != "secret":
raise HTTPException(status_code=401, detail="Invalid API Key")
- state.config_lock.acquire()
try:
logger.info(f"📬 Updating content index via API call by {client} client")
org_files: Dict[str, str] = {}
markdown_files: Dict[str, str] = {}
- pdf_files: Dict[str, str] = {}
+ pdf_files: Dict[str, bytes] = {}
plaintext_files: Dict[str, str] = {}

for file in files:

@@ -86,7 +86,7 @@ async def update(
elif file_type == "markdown":
dict_to_update = markdown_files
elif file_type == "pdf":
- dict_to_update = pdf_files
+ dict_to_update = pdf_files  # type: ignore
elif file_type == "plaintext":
dict_to_update = plaintext_files

@@ -120,30 +120,31 @@ async def update(
github=None,
notion=None,
plaintext=None,
- plugins=None,
)
state.config.content_type = default_content_config
save_config_to_file_updated_state()
configure_search(state.search_models, state.config.search_type)

# Extract required fields from config
- state.content_index = configure_content(
+ loop = asyncio.get_event_loop()
+ state.content_index = await loop.run_in_executor(
+ None,
+ configure_content,
state.content_index,
state.config.content_type,
indexer_input.dict(),
state.search_models,
- regenerate=force,
- t=t,
- full_corpus=False,
+ force,
+ t,
+ False,
+ user,
)
+ logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True,
)
- finally:
- state.config_lock.release()

update_telemetry_state(
request=request,
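Indexing is synchronous and CPU/IO heavy, so the updated endpoint offloads configure_content to a worker thread via run_in_executor instead of blocking the event loop. A small, self-contained sketch of that pattern (slow_index is a placeholder, not the project's indexer):

import asyncio
import time

def slow_index(files: dict, regenerate: bool) -> int:
    # Placeholder for a blocking indexing job (e.g. embedding generation)
    time.sleep(1)
    return len(files)

async def update_index(files: dict, regenerate: bool = False) -> int:
    loop = asyncio.get_event_loop()
    # run_in_executor forwards positional arguments only, which is why the call
    # site in this diff switches from regenerate=force to plain positional force.
    return await loop.run_in_executor(None, slow_index, files, regenerate)

# asyncio.run(update_index({"notes.org": "* hello"}))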
@@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
if search_models is None:
search_models = SearchModels()

- # Initialize Search Models
- if search_config.asymmetric:
- logger.info("🔍 📜 Setting up text search model")
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)

if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
@@ -187,16 +183,9 @@ def configure_content(
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
full_corpus: bool = True,
+ user: KhojUser = None,
) -> Optional[ContentIndex]:
- def has_valid_text_config(config: TextContentConfig):
- return config.input_files or config.input_filter
- # Run Validation Checks
- if content_config is None:
- logger.warning("🚨 No Content configuration available.")
- return None
- if content_index is None:
- content_index = ContentIndex()
+ content_index = ContentIndex()

if t in [type.value for type in state.SearchType]:
t = state.SearchType(t).value
@@ -209,59 +198,30 @@ def configure_content(
try:
# Initialize Org Notes Search
- if (
- (t == None or t == state.SearchType.Org.value)
- and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"])
- and search_models.text_search
- ):
- if content_config.org == None:
- logger.info("🦄 No configuration for orgmode notes. Using default configuration.")
- default_configuration = default_config["content-type"]["org"]  # type: ignore
- content_config.org = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
+ if (t == None or t == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
- content_index.org = text_search.setup(
+ text_search.setup(
OrgToJsonl,
files.get("org"),
- content_config.org,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
try:
# Initialize Markdown Search
- if (
- (t == None or t == state.SearchType.Markdown.value)
- and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"])
- and search_models.text_search
- and files["markdown"]
- ):
- if content_config.markdown == None:
- logger.info("💎 No configuration for markdown notes. Using default configuration.")
- default_configuration = default_config["content-type"]["markdown"]  # type: ignore
- content_config.markdown = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
+ if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
- content_index.markdown = text_search.setup(
+ text_search.setup(
MarkdownToJsonl,
files.get("markdown"),
- content_config.markdown,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -269,30 +229,15 @@ def configure_content(
try:
# Initialize PDF Search
- if (
- (t == None or t == state.SearchType.Pdf.value)
- and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"])
- and search_models.text_search
- and files["pdf"]
- ):
- if content_config.pdf == None:
- logger.info("🖨️ No configuration for pdf notes. Using default configuration.")
- default_configuration = default_config["content-type"]["pdf"]  # type: ignore
- content_config.pdf = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
+ if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
- content_index.pdf = text_search.setup(
+ text_search.setup(
PdfToJsonl,
files.get("pdf"),
- content_config.pdf,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -300,30 +245,15 @@ def configure_content(
try:
# Initialize Plaintext Search
- if (
- (t == None or t == state.SearchType.Plaintext.value)
- and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"])
- and search_models.text_search
- and files["plaintext"]
- ):
- if content_config.plaintext == None:
- logger.info("📄 No configuration for plaintext notes. Using default configuration.")
- default_configuration = default_config["content-type"]["plaintext"]  # type: ignore
- content_config.plaintext = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
+ if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
- content_index.plaintext = text_search.setup(
+ text_search.setup(
PlaintextToJsonl,
files.get("plaintext"),
- content_config.plaintext,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -331,7 +261,12 @@ def configure_content(
try:
# Initialize Image Search
- if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
+ if (
+ (t == None or t == state.SearchType.Image.value)
+ and content_config
+ and content_config.image
+ and search_models.image_search
+ ):
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(

@@ -342,17 +277,17 @@ def configure_content(
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)

try:
- if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
+ github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
+ if (t == None or t == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
- content_index.github = text_search.setup(
+ text_search.setup(
GithubToJsonl,
None,
- content_config.github,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
+ config=github_config,
)
except Exception as e:
@@ -360,42 +295,24 @@ def configure_content(
try:
# Initialize Notion Search
- if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
+ notion_config = NotionConfig.objects.filter(user=user).first()
+ if (t == None or t in state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
- content_index.notion = text_search.setup(
+ text_search.setup(
NotionToJsonl,
None,
- content_config.notion,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
+ config=notion_config,
)

except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)

- try:
- # Initialize External Plugin Search
- if t == None and content_config.plugins and search_models.text_search:
- logger.info("🔌 Setting up search for plugins")
- content_index.plugins = {}
- for plugin_type, plugin_config in content_config.plugins.items():
- content_index.plugins[plugin_type] = text_search.setup(
- JsonlToJsonl,
- None,
- plugin_config,
- search_models.text_search.bi_encoder,
- regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
- full_corpus=full_corpus,
- )
- except Exception as e:
- logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True)

# Invalidate Query Cache
- state.query_cache = LRU()
+ if user:
+ state.query_cache[user.uuid] = LRU()

return content_index
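Because indexing is now per account, the query cache is partitioned by user and only that user's slice is reset after a re-index. A small, self-contained sketch of that idea (a plain dict stands in for Khoj's LRU helper):

from collections import OrderedDict

query_cache: dict = {}

def invalidate_user_cache(user_uuid: str) -> None:
    # Drop only the cached queries belonging to the user whose content changed;
    # other users keep their warm caches.
    query_cache[user_uuid] = OrderedDict()

def cached_search(user_uuid: str, cache_key: tuple, compute):
    user_cache = query_cache.setdefault(user_uuid, OrderedDict())
    if cache_key not in user_cache:
        user_cache[cache_key] = compute()
    return user_cache[cache_key]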
@ -412,44 +329,9 @@ def load_content(
if content_index is None: if content_index is None:
content_index = ContentIndex() content_index = ContentIndex()
if content_config.org:
logger.info("🦄 Loading orgmode notes")
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.markdown:
logger.info("💎 Loading markdown notes")
content_index.markdown = text_search.load(
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.pdf:
logger.info("🖨️ Loading pdf")
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.plaintext:
logger.info("📄 Loading plaintext")
content_index.plaintext = text_search.load(
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.image: if content_config.image:
logger.info("🌄 Loading images") logger.info("🌄 Loading images")
content_index.image = image_search.setup( content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False content_config.image, search_models.image_search.image_encoder, regenerate=False
) )
if content_config.github:
logger.info("🐙 Loading github")
content_index.github = text_search.load(
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.notion:
logger.info("🔌 Loading notion")
content_index.notion = text_search.load(
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.plugins:
logger.info("🔌 Loading plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.load(
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
)
state.query_cache = LRU()
return content_index return content_index

@@ -3,10 +3,20 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
- from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
+ from starlette.authentication import requires
+ from khoj.utils.rawconfig import (
+ TextContentConfig,
+ OpenAIProcessorConfig,
+ FullConfig,
+ GithubContentConfig,
+ GithubRepoConfig,
+ NotionContentConfig,
+ )

# Internal Packages
from khoj.utils import constants, state
+ from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
+ from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig

import json

@@ -29,10 +39,23 @@ def chat_page(request: Request):
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})

+ def map_config_to_object(content_type: str):
+ if content_type == "org":
+ return LocalOrgConfig
+ if content_type == "markdown":
+ return LocalMarkdownConfig
+ if content_type == "pdf":
+ return LocalPdfConfig
+ if content_type == "plaintext":
+ return LocalPlaintextConfig
if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request):
+ user = request.user.object if request.user.is_authenticated else None
+ enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
search_type=None,

@@ -41,13 +64,13 @@ if not state.demo:
current_config = state.config or json.loads(default_full_config.json())

successfully_configured = {
- "pdf": False,
- "markdown": False,
- "org": False,
+ "pdf": ("pdf" in enabled_content),
+ "markdown": ("markdown" in enabled_content),
+ "org": ("org" in enabled_content),
"image": False,
- "github": False,
- "notion": False,
- "plaintext": False,
+ "github": ("github" in enabled_content),
+ "notion": ("notion" in enabled_content),
+ "plaintext": ("plaintext" in enabled_content),
"enable_offline_model": False,
"conversation_openai": False,
"conversation_gpt4all": False,

@@ -56,13 +79,7 @@ if not state.demo:
if state.content_index:
successfully_configured.update(
{
- "pdf": state.content_index.pdf is not None,
- "markdown": state.content_index.markdown is not None,
- "org": state.content_index.org is not None,
"image": state.content_index.image is not None,
- "github": state.content_index.github is not None,
- "notion": state.content_index.notion is not None,
- "plaintext": state.content_index.plaintext is not None,
}
)
@ -84,22 +101,29 @@ if not state.demo:
) )
@web_client.get("/config/content_type/github", response_class=HTMLResponse) @web_client.get("/config/content_type/github", response_class=HTMLResponse)
@requires(["authenticated"])
def github_config_page(request: Request): def github_config_page(request: Request):
default_copy = constants.default_config.copy() user = request.user.object if request.user.is_authenticated else None
default_github = default_copy["content-type"]["github"] # type: ignore current_github_config = get_user_github_config(user)
default_config = TextContentConfig( if current_github_config:
compressed_jsonl=default_github["compressed-jsonl"], raw_repos = current_github_config.githubrepoconfig.all()
embeddings_file=default_github["embeddings-file"], repos = []
) for repo in raw_repos:
repos.append(
current_config = ( GithubRepoConfig(
state.config.content_type.github name=repo.name,
if state.config and state.config.content_type and state.config.content_type.github owner=repo.owner,
else default_config branch=repo.branch,
) )
)
current_config = json.loads(current_config.json()) current_config = GithubContentConfig(
pat_token=current_github_config.pat_token,
repos=repos,
)
current_config = json.loads(current_config.json())
else:
current_config = {} # type: ignore
return templates.TemplateResponse( return templates.TemplateResponse(
"content_type_github_input.html", context={"request": request, "current_config": current_config} "content_type_github_input.html", context={"request": request, "current_config": current_config}
@ -107,18 +131,11 @@ if not state.demo:
@web_client.get("/config/content_type/notion", response_class=HTMLResponse) @web_client.get("/config/content_type/notion", response_class=HTMLResponse)
def notion_config_page(request: Request): def notion_config_page(request: Request):
default_copy = constants.default_config.copy() user = request.user.object if request.user.is_authenticated else None
default_notion = default_copy["content-type"]["notion"] # type: ignore current_notion_config = get_user_notion_config(user)
default_config = TextContentConfig( current_config = NotionContentConfig(
compressed_jsonl=default_notion["compressed-jsonl"], token=current_notion_config.token if current_notion_config else "",
embeddings_file=default_notion["embeddings-file"],
)
current_config = (
state.config.content_type.notion
if state.config and state.config.content_type and state.config.content_type.notion
else default_config
) )
current_config = json.loads(current_config.json()) current_config = json.loads(current_config.json())
@ -132,18 +149,16 @@ if not state.demo:
if content_type not in VALID_TEXT_CONTENT_TYPES: if content_type not in VALID_TEXT_CONTENT_TYPES:
return templates.TemplateResponse("config.html", context={"request": request}) return templates.TemplateResponse("config.html", context={"request": request})
default_copy = constants.default_config.copy() object = map_config_to_object(content_type)
default_content_type = default_copy["content-type"][content_type] # type: ignore user = request.user.object if request.user.is_authenticated else None
config = object.objects.filter(user=user).first()
if config == None:
config = object.objects.create(user=user)
default_config = TextContentConfig( current_config = TextContentConfig(
compressed_jsonl=default_content_type["compressed-jsonl"], input_files=config.input_files,
embeddings_file=default_content_type["embeddings-file"], input_filter=config.input_filter,
) index_heading_entries=config.index_heading_entries,
current_config = (
state.config.content_type[content_type]
if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore
else default_config
) )
current_config = json.loads(current_config.json()) current_config = json.loads(current_config.json())

@@ -1,16 +1,9 @@
# Standard Packages
from abc import ABC, abstractmethod
- from typing import List, Set, Tuple
+ from typing import List

- # Internal Packages
- from khoj.utils.rawconfig import Entry

class BaseFilter(ABC):
- @abstractmethod
- def load(self, entries: List[Entry], *args, **kwargs):
- ...

@abstractmethod
def get_filter_terms(self, query: str) -> List[str]:
...

@@ -18,10 +11,6 @@ class BaseFilter(ABC):
def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0

- @abstractmethod
- def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
- ...

@abstractmethod
def defilter(self, query: str) -> str:
...
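After this change a filter only has to report the filter terms it finds in a query and strip them out; entry-level filtering happens in SQL. A minimal sketch of a concrete filter under the slimmed-down interface (the hashtag syntax is an invented example, not one of Khoj's filters):

import re
from abc import ABC, abstractmethod
from typing import List

class BaseFilter(ABC):
    @abstractmethod
    def get_filter_terms(self, query: str) -> List[str]: ...

    def can_filter(self, raw_query: str) -> bool:
        return len(self.get_filter_terms(raw_query)) > 0

    @abstractmethod
    def defilter(self, query: str) -> str: ...

class HashtagFilter(BaseFilter):
    "Example filter matching #tag terms in a query."
    tag_regex = r"#(\w+)"

    def get_filter_terms(self, query: str) -> List[str]:
        return re.findall(self.tag_regex, query)

    def defilter(self, query: str) -> str:
        return re.sub(self.tag_regex, "", query).strip()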

@@ -25,72 +25,42 @@ class DateFilter(BaseFilter):
# - dt>="last week"
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
+ raw_date_regex = r"\d{4}-\d{2}-\d{2}"

def __init__(self, entry_key="compiled"):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()

- def load(self, entries, *args, **kwargs):
- with timer("Created date filter index", logger):
- for id, entry in enumerate(entries):
- # Extract dates from entry
- for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
- # Convert date string in entry to unix timestamp
+ def extract_dates(self, content):
+ pattern_matched_dates = re.findall(self.raw_date_regex, content)
+ # Filter down to valid dates
+ valid_dates = []
+ for date_str in pattern_matched_dates:
try:
- date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
+ valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d"))
except ValueError:
continue
- except OSError:
- logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}")
- continue
- self.date_to_entry_ids[date_in_entry].add(id)
+ return valid_dates

def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]

+ def get_query_date_range(self, query) -> List:
+ with timer("Extract date range to filter from query", logger):
+ query_daterange = self.extract_date_range(query)
+ return query_daterange

def defilter(self, query):
# remove date range filter from query
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip()  # remove multiple spaces
return query
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
with timer("Extract date range to filter from query", logger):
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange == []:
return query, set(range(len(entries)))
query = self.defilter(query)
# return results from cache if exists
cache_key = tuple(query_daterange)
if cache_key in self.cache:
logger.debug(f"Return date filter results from cache")
entries_to_include = self.cache[cache_key]
return query, entries_to_include
if not self.date_to_entry_ids:
self.load(entries)
# find entries containing any dates that fall with date range specified in query
with timer("Mark entries satisfying filter", logger):
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry]
# cache results
self.cache[cache_key] = entries_to_include
return query, entries_to_include
def extract_date_range(self, query):
# find date range filter in query
date_range_matches = re.findall(self.date_regex, query)

@@ -138,6 +108,15 @@ class DateFilter(BaseFilter):
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return []
else:
+ # If the first element is 0, replace it with None
+ if effective_date_range[0] == 0:
+ effective_date_range[0] = None
+ # If the second element is inf, replace it with None
+ if effective_date_range[1] == inf:
+ effective_date_range[1] = None
return effective_date_range

def parse(self, date_str, relative_base=None):
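The date filter no longer indexes entries itself; it pulls ISO dates out of text and turns dt filters into an open-ended [start, end] range, with None meaning unbounded, which the database layer can translate into SQL comparisons. A small, self-contained sketch of that behavior, using a simplified range-to-filter mapping rather than Khoj's full natural-language date parsing:

import re
from datetime import datetime
from typing import List

RAW_DATE_REGEX = r"\d{4}-\d{2}-\d{2}"

def extract_dates(content: str) -> List[datetime]:
    # Keep only strings that parse as real calendar dates
    valid = []
    for match in re.findall(RAW_DATE_REGEX, content):
        try:
            valid.append(datetime.strptime(match, "%Y-%m-%d"))
        except ValueError:
            continue
    return valid

def date_range_to_filters(date_range):
    # Mirror the diff's convention: 0 becomes None (no lower bound), inf becomes None (no upper bound)
    lower, upper = date_range
    filters = {}
    if lower is not None:
        filters["date__gte"] = lower
    if upper is not None:
        filters["date__lt"] = upper
    return filters

print(extract_dates("Met on 1984-04-01 and again on 1984-04-02"))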

@@ -21,62 +21,13 @@ class FileFilter(BaseFilter):
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()

- def load(self, entries, *args, **kwargs):
- with timer("Created file filter index", logger):
- for id, entry in enumerate(entries):
- self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)

def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
- return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
+ return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)]

+ def convert_to_regex(self, file_filter: str) -> str:
+ "Convert file filter to regex"
+ return file_filter.replace(".", r"\.").replace("*", r".*")

def defilter(self, query: str) -> str:
return re.sub(self.file_filter_regex, "", query).strip()
def apply(self, query, entries):
# Extract file filters from raw query
with timer("Extract files_to_search from query", logger):
raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search:
return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if "/" not in file and "\\" not in file and "*" not in file:
files_to_search += [f"*{file}"]
else:
files_to_search += [file]
# Remove filter terms from original query
query = self.defilter(query)
# Return item from cache if exists
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.debug(f"Return file filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.file_to_entry_map:
self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion
with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(
*[
self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)
],
set(),
)
if not included_entry_indices:
return query, {}
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
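File filters are now converted into regex terms that the database can match against each entry's file path, instead of fnmatch-ing an in-memory index. A self-contained sketch of that conversion and how it could feed a regex match in SQL (the table/column names and WHERE clause are illustrative only):

import re

def convert_to_regex(file_filter: str) -> str:
    # Escape literal dots and turn shell-style * globs into regex wildcards
    return file_filter.replace(".", r"\.").replace("*", r".*")

def build_file_filter_clause(file_filters):
    # Each term becomes one alternative in a single regex; entries could then be
    # filtered with something like: WHERE file_path ~ %s  (PostgreSQL regex match)
    pattern = "|".join(convert_to_regex(term) for term in file_filters)
    return "file_path ~ %s", [pattern]

print(convert_to_regex("*notes.org"))  # .*notes\.org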

@@ -6,7 +6,7 @@ from typing import List
# Internal Packages
from khoj.search_filter.base_filter import BaseFilter
- from khoj.utils.helpers import LRU, timer
+ from khoj.utils.helpers import LRU

logger = logging.getLogger(__name__)

@@ -22,21 +22,6 @@ class WordFilter(BaseFilter):
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()

- def load(self, entries, *args, **kwargs):
- with timer("Created word filter index", logger):
- self.cache = {}  # Clear cache on filter (re-)load
- entry_splitter = (
- r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
- )
- # Create map of words to entries they exist in
- for entry_index, entry in enumerate(entries):
- for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
- if word == "":
- continue
- self.word_to_entry_index[word].add(entry_index)
- return self.word_to_entry_index

def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]

@@ -45,47 +30,3 @@ class WordFilter(BaseFilter):
def defilter(self, query: str) -> str:
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = self.defilter(query)
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))
# Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache:
logger.debug(f"Return word filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.word_to_entry_index:
self.load(entries, regenerate=False)
# mark entries that contain all required_words for inclusion
with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(
*[self.word_to_entry_index.get(word, set()) for word in required_words]
)
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
)
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
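Word filters keep only the +required and -blocked term syntax; inclusion and exclusion themselves are left to the database query. A short, self-contained sketch of extracting those terms and de-filtering the query (regexes simplified from the ones in this file):

import re
from typing import List, Tuple

REQUIRED_REGEX = r'\+"?([\w]+)"?'   # +word or +"word"
BLOCKED_REGEX = r'-"?([\w]+)"?'     # -word or -"word"

def split_word_filters(query: str) -> Tuple[List[str], List[str], str]:
    required = [w.lower() for w in re.findall(REQUIRED_REGEX, query)]
    blocked = [w.lower() for w in re.findall(BLOCKED_REGEX, query)]
    # Strip the filter syntax so only the natural-language query gets embedded
    defiltered = re.sub(BLOCKED_REGEX, "", re.sub(REQUIRED_REGEX, "", query)).strip()
    return required, blocked, defiltered

print(split_word_filters('+khoj -draft project updates'))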

@@ -2,25 +2,39 @@
import logging
import math
from pathlib import Path
- from typing import List, Tuple, Type, Union
+ from typing import List, Tuple, Type, Union, Dict

# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
- from khoj.processor.text_to_jsonl import TextToJsonl
- from khoj.search_filter.base_filter import BaseFilter
+ from asgiref.sync import sync_to_async

# Internal Packages
from khoj.utils import state
- from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
- from khoj.utils.config import TextContent, TextSearchModel
+ from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
+ from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
- from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
+ from khoj.utils.state import SearchType
+ from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
from khoj.utils.jsonl import load_jsonl
+ from khoj.processor.text_to_jsonl import TextEmbeddings
+ from database.adapters import EmbeddingsAdapters
+ from database.models import KhojUser, Embeddings

logger = logging.getLogger(__name__)

+ search_type_to_embeddings_type = {
+ SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
+ SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
+ SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
+ SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
+ SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
+ SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
+ SearchType.All.value: None,
+ }

def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
@@ -117,171 +131,102 @@ def load_embeddings(
async def query(
+ user: KhojUser,
raw_query: str,
- search_model: TextSearchModel,
- content: TextContent,
+ type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
- dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
- if (
- content.entries is None
- or len(content.entries) == 0
- or content.corpus_embeddings is None
- or len(content.corpus_embeddings) == 0
- ):
- return [], []
- query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
- # Filter query, entries and embeddings before semantic search
- query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
- # If no entries left after filtering, return empty results
- if entries is None or len(entries) == 0:
- return [], []
- # If query only had filters it'll be empty now. So short-circuit and return results.
- if query.strip() == "":
- hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
- return hits, entries
+ file_type = search_type_to_embeddings_type[type.value]
+ query = raw_query

# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
- question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
- question_embedding = util.normalize_embeddings(question_embedding)
+ question_embedding = state.embeddings_model.embed_query(query)

# Find relevant entries for the query
- top_k = min(len(entries), search_model.top_k or 10)  # top_k hits can't be more than the total entries in corpus
+ top_k = 10
with timer("Search Time", logger, state.device):
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
+ hits = EmbeddingsAdapters.search_with_embeddings(
+ user=user,
+ embeddings=question_embedding,
+ max_results=top_k,
+ file_type_filter=file_type,
+ raw_query=raw_query,
+ ).all()
+ hits = await sync_to_async(list)(hits)  # type: ignore[call-arg]
+ return hits
+ def collate_results(hits, dedupe=True):
+ hit_ids = set()
+ for hit in hits:
+ if dedupe and hit.corpus_id in hit_ids:
+ continue
+ else:
+ hit_ids.add(hit.corpus_id)
+ yield SearchResponse.parse_obj(
+ {
+ "entry": hit.raw,
+ "score": hit.distance,
+ "additional": {
+ "file": hit.file_path,
+ "compiled": hit.compiled,
+ "heading": hit.heading,
+ },
+ }
+ )

+ def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
- if rank_results and search_model.cross_encoder:
- hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
+ hits = cross_encoder_score(query, hits)

- # Filter results by score threshold
- hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]

- # Order results by cross-encoder score followed by bi-encoder score
- hits = sort_results(rank_results, hits)
- # Deduplicate entries by raw entry text before showing to users
- if dedupe:
- hits = deduplicate_results(entries, hits)
- return hits, entries
+ # Sort results by cross-encoder score followed by bi-encoder score
+ hits = sort_results(rank_results=True, hits=hits)
+ return hits
- def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
- return [
- SearchResponse.parse_obj(
- {
- "entry": entries[hit["corpus_id"]].raw,
- "score": f"{hit.get('cross-score') or hit.get('score')}",
- "additional": {
- "file": entries[hit["corpus_id"]].file,
- "compiled": entries[hit["corpus_id"]].compiled,
- "heading": entries[hit["corpus_id"]].heading,
- },
- }
- )
- for hit in hits[0:count]
- ]

def setup(
- text_to_jsonl: Type[TextToJsonl],
+ text_to_jsonl: Type[TextEmbeddings],
files: dict[str, str],
- config: TextConfigBase,
- bi_encoder: BaseEncoder,
regenerate: bool,
- filters: List[BaseFilter] = [],
- normalize: bool = True,
full_corpus: bool = True,
- ) -> TextContent:
- # Map notes in text files to (compressed) JSONL formatted file
- config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
- previous_entries = []
- if config.compressed_jsonl.exists() and not regenerate:
- previous_entries = extract_entries(config.compressed_jsonl)
- entries_with_indices = text_to_jsonl(config).process(
- previous_entries=previous_entries, files=files, full_corpus=full_corpus
- )
+ user: KhojUser = None,
+ config=None,
+ ) -> None:
+ if config:
+ num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process(
+ files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
+ )
+ else:
+ num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process(
+ files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
+ )

- # Extract Updated Entries
- entries = extract_entries(config.compressed_jsonl)
- if is_none_or_empty(entries):
- config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
- raise ValueError(
- f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}"
- )

- # Compute or Load Embeddings
- config.embeddings_file = resolve_absolute_path(config.embeddings_file)
- corpus_embeddings = compute_embeddings(
- entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
- )
+ file_names = [file_name for file_name in files]
+ logger.info(
+ f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
+ )

- for filter in filters:
- filter.load(entries, regenerate=regenerate)

- return TextContent(entries, corpus_embeddings, filters)
+ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
def load(
config: TextConfigBase,
filters: List[BaseFilter] = [],
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
entries = extract_entries(config.compressed_jsonl)
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = load_embeddings(config.embeddings_file)
for filter in filters:
filter.load(entries, regenerate=False)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:
"""Filter query, entries and embeddings before semantic search"""
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return "", [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
)
return query, entries, corpus_embeddings
- def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
- cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
- cross_scores = cross_encoder.predict(cross_inp)
+ cross_scores = state.cross_encoder_model.predict(query, hits)

# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
- hits[idx]["cross-score"] = cross_scores[idx]
+ hits[idx]["cross_score"] = cross_scores[idx]

return hits
@@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True)  # sort by bi-encoder score
if rank_results:
- hits.sort(key=lambda x: x["cross-score"], reverse=True)  # sort by cross-encoder score
+ hits.sort(key=lambda x: x["cross_score"], reverse=True)  # sort by cross-encoder score
return hits

- def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
- """Deduplicate entries by raw entry text before showing to users
- Compiled entries are split by max tokens supported by ML models.
- This can result in duplicate hits, entries shown to user."""
- with timer("Deduplication Time", logger, state.device):
- seen, original_hits_count = set(), len(hits)
- hits = [
- hit
- for hit in hits
- if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw)  # type: ignore[func-returns-value]
- ]
- duplicate_hits = original_hits_count - len(hits)
- logger.debug(f"Removed {duplicate_hits} duplicates")
- return hits
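Retrieval now embeds the query once, asks the database for nearest neighbours scoped to the user (and optionally a file type), then collates, dedupes by corpus_id, and reranks. A compact, self-contained sketch of that pipeline using plain Python and cosine distance in place of the pgvector-backed EmbeddingsAdapters:

import math
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Hit:
    corpus_id: str
    raw: str
    file_path: str
    distance: float  # lower is better, like pgvector's cosine distance

def cosine_distance(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
    return 1 - dot / norm

def search(query_embedding, entries, max_results=10, file_type: Optional[str] = None):
    # Stand-in for EmbeddingsAdapters.search_with_embeddings: filter, then rank by distance
    candidates = [e for e in entries if file_type is None or e["file_type"] == file_type]
    hits = [
        Hit(e["corpus_id"], e["raw"], e["file_path"], cosine_distance(query_embedding, e["embedding"]))
        for e in candidates
    ]
    return sorted(hits, key=lambda h: h.distance)[:max_results]

def collate(hits: List[Hit], dedupe: bool = True):
    seen = set()
    for hit in hits:
        if dedupe and hit.corpus_id in seen:
            continue
        seen.add(hit.corpus_id)
        yield {"entry": hit.raw, "score": hit.distance, "additional": {"file": hit.file_path}}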

@@ -2,6 +2,7 @@
import argparse
import pathlib
from importlib.metadata import version
+ import os

# Internal Packages
from khoj.utils.helpers import resolve_absolute_path

@@ -34,6 +35,12 @@ def cli(args=None):
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
+ parser.add_argument(
+ "--anonymous-mode",
+ action="store_true",
+ default=False,
+ help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
+ )

args = parser.parse_args(args)

@@ -51,6 +58,8 @@ def cli(args=None):
else:
args = run_migrations(args)
args.config = parse_config_from_file(args.config_file)
+ if os.environ.get("DEBUG"):
+ args.config.app.should_log_telemetry = False

return args

@@ -41,9 +41,7 @@ class ProcessorType(str, Enum):
@dataclass
class TextContent:
- entries: List[Entry]
- corpus_embeddings: torch.Tensor
- filters: List[BaseFilter]
+ enabled: bool

@dataclass

@@ -67,21 +65,13 @@ class ImageSearchModel:
@dataclass
class ContentIndex:
- org: Optional[TextContent] = None
- markdown: Optional[TextContent] = None
- pdf: Optional[TextContent] = None
- github: Optional[TextContent] = None
- notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
- plaintext: Optional[TextContent] = None
- plugins: Optional[Dict[str, TextContent]] = None

@dataclass
class SearchModels:
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
- plugin_search: Optional[Dict[str, TextSearchModel]] = None

@dataclass

@@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
+ content_directory = "~/.khoj/content/"

empty_config = {
"content-type": {

@@ -5,29 +5,39 @@ from typing import Optional
from bs4 import BeautifulSoup

from khoj.utils.helpers import get_absolute_path, is_none_or_empty
- from khoj.utils.rawconfig import TextContentConfig, ContentConfig
+ from khoj.utils.rawconfig import TextContentConfig
from khoj.utils.config import SearchType
+ from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig

logger = logging.getLogger(__name__)

- def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
+ def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
files = {}
- if config is None:
- return files
if search_type == SearchType.All or search_type == SearchType.Org:
- files["org"] = get_org_files(config.org) if config.org else {}
+ org_config = LocalOrgConfig.objects.filter(user=user).first()
+ files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
if search_type == SearchType.All or search_type == SearchType.Markdown:
- files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
+ markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
+ files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
if search_type == SearchType.All or search_type == SearchType.Plaintext:
- files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
+ plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
+ files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
if search_type == SearchType.All or search_type == SearchType.Pdf:
- files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
+ pdf_config = LocalPdfConfig.objects.filter(user=user).first()
+ files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
return files

+ def construct_config_from_db(db_config) -> TextContentConfig:
+ return TextContentConfig(
+ input_files=db_config.input_files,
+ input_filter=db_config.input_filter,
+ index_heading_entries=db_config.index_heading_entries,
+ )

def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
def is_plaintextfile(file: str):
"Check if file is plaintext file"

@@ -209,10 +209,12 @@ def log_telemetry(
if not app_config or not app_config.should_log_telemetry:
return []

+ if properties.get("server_id") is None:
+ properties["server_id"] = get_server_id()

# Populate telemetry data to log
request_body = {
"telemetry_type": telemetry_type,
- "server_id": get_server_id(),
"server_version": version("khoj-assistant"),
"os": platform.system(),
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),

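The only behavioural change here is the server_id fallback: a caller may now supply its own server_id inside properties, and log_telemetry only derives one via get_server_id() when none was given. A tiny standalone sketch of that fallback (resolve_server_id is a hypothetical name, not part of the codebase):

```python
def resolve_server_id(properties: dict, get_server_id) -> dict:
    # Mirror of the new behaviour: respect a caller-supplied server_id,
    # otherwise fall back to the locally derived identifier.
    if properties.get("server_id") is None:
        properties["server_id"] = get_server_id()
    return properties

assert resolve_server_id({}, lambda: "generated")["server_id"] == "generated"
assert resolve_server_id({"server_id": "client"}, lambda: "generated")["server_id"] == "client"
```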

@ -1,13 +1,14 @@
# System Packages # System Packages
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional, Union, Any from typing import List, Dict, Optional
import uuid
# External Packages # External Packages
from pydantic import BaseModel, validator from pydantic import BaseModel
# Internal Packages # Internal Packages
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty from khoj.utils.helpers import to_snake_case_from_dash
class ConfigBase(BaseModel): class ConfigBase(BaseModel):
@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase):
embeddings_file: Path embeddings_file: Path
class TextContentConfig(TextConfigBase): class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]] input_files: Optional[List[Path]]
input_filter: Optional[List[str]] input_filter: Optional[List[str]]
index_heading_entries: Optional[bool] = False index_heading_entries: Optional[bool] = False
@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase):
branch: Optional[str] = "master" branch: Optional[str] = "master"
class GithubContentConfig(TextConfigBase): class GithubContentConfig(ConfigBase):
pat_token: str pat_token: str
repos: List[GithubRepoConfig] repos: List[GithubRepoConfig]
class NotionContentConfig(TextConfigBase): class NotionContentConfig(ConfigBase):
token: str token: str
@ -63,7 +64,6 @@ class ContentConfig(ConfigBase):
pdf: Optional[TextContentConfig] pdf: Optional[TextContentConfig]
plaintext: Optional[TextContentConfig] plaintext: Optional[TextContentConfig]
github: Optional[GithubContentConfig] github: Optional[GithubContentConfig]
plugins: Optional[Dict[str, TextContentConfig]]
notion: Optional[NotionContentConfig] notion: Optional[NotionContentConfig]
@ -122,7 +122,8 @@ class FullConfig(ConfigBase):
class SearchResponse(ConfigBase): class SearchResponse(ConfigBase):
entry: str entry: str
score: str score: float
cross_score: Optional[float]
additional: Optional[dict] additional: Optional[dict]
@ -131,14 +132,21 @@ class Entry:
compiled: str compiled: str
heading: Optional[str] heading: Optional[str]
file: Optional[str] file: Optional[str]
corpus_id: str
def __init__( def __init__(
self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None self,
raw: str = None,
compiled: str = None,
heading: Optional[str] = None,
file: Optional[str] = None,
corpus_id: uuid.UUID = None,
): ):
self.raw = raw self.raw = raw
self.compiled = compiled self.compiled = compiled
self.heading = heading self.heading = heading
self.file = file self.file = file
self.corpus_id = str(corpus_id)
def to_json(self) -> str: def to_json(self) -> str:
return json.dumps(self.__dict__, ensure_ascii=False) return json.dumps(self.__dict__, ensure_ascii=False)
@ -153,4 +161,5 @@ class Entry:
compiled=dictionary["compiled"], compiled=dictionary["compiled"],
file=dictionary.get("file", None), file=dictionary.get("file", None),
heading=dictionary.get("heading", None), heading=dictionary.get("heading", None),
corpus_id=dictionary.get("corpus_id", None),
) )

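Entry gains a corpus_id field (kept as a string on the instance) so each indexed chunk can be tied back to a row in the embeddings table, and SearchResponse.score becomes a float with an optional cross_score. A small sketch using only the fields visible above:

```python
import uuid

from khoj.utils.rawconfig import Entry

entry = Entry(
    raw="* Install\ngit clone https://github.com/khoj-ai/khoj",
    compiled="Install: git clone https://github.com/khoj-ai/khoj",
    heading="Install",
    file="readme.org",
    corpus_id=uuid.uuid4(),  # stored on the instance as str(corpus_id)
)

# from_dict accepts the new field; the constructor normalizes it to a string
restored = Entry.from_dict(
    {
        "raw": entry.raw,
        "compiled": entry.compiled,
        "heading": entry.heading,
        "file": entry.file,
        "corpus_id": entry.corpus_id,
    }
)
assert restored.corpus_id == entry.corpus_id
```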

@@ -2,6 +2,7 @@
 import threading
 from typing import List, Dict
 from packaging import version
+from collections import defaultdict

 # External Packages
 import torch
@@ -12,10 +13,13 @@ from khoj.utils import config as utils_config
 from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
 from khoj.utils.helpers import LRU
 from khoj.utils.rawconfig import FullConfig
+from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel

 # Application Global State
 config = FullConfig()
 search_models = SearchModels()
+embeddings_model = EmbeddingsModel()
+cross_encoder_model = CrossEncoderModel()
 content_index = ContentIndex()
 processor_config = ProcessorConfigModel()
 config_file: Path = None
@@ -23,14 +27,14 @@ verbose: int = 0
 host: str = None
 port: int = None
 cli_args: List[str] = None
-query_cache = LRU()
+query_cache: Dict[str, LRU] = defaultdict(LRU)
 config_lock = threading.Lock()
 chat_lock = threading.Lock()
 SearchType = utils_config.SearchType
 telemetry: List[Dict[str, str]] = []
-previous_query: str = None
 demo: bool = False
 khoj_version: str = None
+anonymous_mode: bool = False

 if torch.cuda.is_available():
     # Use CUDA GPU

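state.query_cache becomes a mapping of per-user LRU caches instead of one shared LRU, so cached query results can no longer bleed between accounts; the removed previous_query global goes away for the same reason. A minimal sketch of the intended access pattern, assuming khoj's LRU helper behaves like a bounded dict (the cache-key scheme here is illustrative):

```python
from collections import defaultdict

from khoj.utils.helpers import LRU

query_cache = defaultdict(LRU)  # one LRU per user key, created on first access


def cached_search(user_key: str, query: str, run_search):
    cache = query_cache[user_key]  # this user's private cache bucket
    if query in cache:
        return cache[query]
    results = run_search(query)
    cache[query] = results
    return results
```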

@ -1,15 +1,19 @@
# External Packages # External Packages
import os import os
from copy import deepcopy
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pathlib import Path from pathlib import Path
import pytest import pytest
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import factory
import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages # Internal Packages
from app.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
OpenAIProcessorConfig, OpenAIProcessorConfig,
ProcessorConfig, ProcessorConfig,
TextContentConfig, TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
ImageContentConfig, ImageContentConfig,
SearchConfig, SearchConfig,
TextSearchConfig, TextSearchConfig,
@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
) )
from khoj.utils import state, fs_syncer from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content from khoj.routers.indexer import configure_content
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter from database.models import (
from khoj.search_filter.word_filter import WordFilter LocalOrgConfig,
from khoj.search_filter.file_filter import FileFilter LocalMarkdownConfig,
LocalPlaintextConfig,
LocalPdfConfig,
GithubConfig,
KhojUser,
GithubRepoConfig,
)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
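Tests now mint isolated KhojUser rows through factory_boy's DjangoModelFactory, with Faker supplying unique usernames, emails and UUIDs. A short sketch of how such a factory is typically consumed in a test (the test name and field override are illustrative):

```python
import pytest


@pytest.mark.django_db
def test_each_factory_call_creates_a_distinct_user():
    user_a = UserFactory()  # persisted KhojUser with Faker-generated fields
    user_b = UserFactory(username="fixed-name")  # declarations can be overridden
    assert user_a.pk != user_b.pk
    assert user_b.username == "fixed-name"
```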
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
return search_config return search_config
@pytest.mark.django_db
@pytest.fixture
def default_user():
return UserFactory()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_models(search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels() search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image) search_models.image_search = image_search.initialize_model(search_config.image)
return search_models return search_models
@pytest.fixture(scope="session") @pytest.fixture
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig): def anyio_backend():
return "asyncio"
@pytest.mark.django_db
@pytest.fixture(scope="function")
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
content_dir = tmp_path_factory.mktemp("content") content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False) image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes LocalOrgConfig.objects.create(
content_config.org = TextContentConfig(
input_files=None, input_files=None,
input_filter=["tests/data/org/*.org"], input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"), index_heading_entries=False,
embeddings_file=content_dir.joinpath("note_embeddings.pt"), user=default_user,
) )
filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
content_config.plugins = {
"plugin1": TextContentConfig(
input_files=[content_dir.joinpath("notes.jsonl.gz")],
input_filter=None,
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
)
}
if os.getenv("GITHUB_PAT_TOKEN"): if os.getenv("GITHUB_PAT_TOKEN"):
content_config.github = GithubContentConfig( GithubConfig.objects.create(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""), pat_token=os.getenv("GITHUB_PAT_TOKEN"),
repos=[ user=default_user,
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
) )
content_config.plaintext = TextContentConfig( GithubRepoConfig.objects.create(
owner="khoj-ai",
name="lantern",
branch="master",
github_config=GithubConfig.objects.get(user=default_user),
)
LocalPlaintextConfig.objects.create(
input_files=None, input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"], input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"), user=default_user,
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
)
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl,
None,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
) )
return content_config return content_config
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def md_content_config(tmp_path_factory): def md_content_config():
content_dir = tmp_path_factory.mktemp("content") markdown_config = LocalMarkdownConfig.objects.create(
# Generate Embeddings for Markdown Content
content_config = ContentConfig()
content_config.markdown = TextContentConfig(
input_files=None, input_files=None,
input_filter=["tests/data/markdown/*.markdown"], input_filter=["tests/data/markdown/*.markdown"],
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
) )
return content_config return markdown_config
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig): def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
# Initialize app state # Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config state.config.search_type = search_config
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search # Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) all_files = fs_syncer.collect_files()
all_files = fs_syncer.collect_files(state.config.content_type)
state.content_index = configure_content( state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models state.content_index, state.config.content_type, all_files, state.search_models
) )
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(processor_config) state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app) configure_routes(app)
configure_middleware(app) configure_middleware(app)
@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig): def fastapi_app():
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return app
@pytest.fixture(scope="function")
def client(
content_config: ContentConfig,
search_config: SearchConfig,
processor_config: ProcessorConfig,
default_user: KhojUser,
):
state.config.content_type = content_config state.config.content_type = content_config
state.config.search_type = search_config state.config.search_type = search_config
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types # These lines help us Mock the Search models for these search types
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image) state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup( text_search.setup(
OrgToJsonl, OrgToJsonl,
get_sample_data("org"), get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False, regenerate=False,
user=default_user,
) )
state.content_index.image = image_search.setup( state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False content_config.image, state.search_models.image_search, regenerate=False
) )
state.content_index.plaintext = text_search.setup( text_search.setup(
PlaintextToJsonl, PlaintextToJsonl,
get_sample_data("plaintext"), get_sample_data("plaintext"),
content_config.plaintext,
state.search_models.text_search.bi_encoder,
regenerate=False, regenerate=False,
user=default_user,
) )
state.processor_config = configure_processor(processor_config) state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
configure_routes(app) configure_routes(app)
configure_middleware(app) configure_middleware(app)
@ -288,7 +285,6 @@ def client_offline_chat(
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search # Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image) state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(state.config.content_type) all_files = fs_syncer.collect_files(state.config.content_type)
@ -298,6 +294,7 @@ def client_offline_chat(
# Initialize Processor from Config # Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat) state.processor_config = configure_processor(processor_config_offline_chat)
state.anonymous_mode = True
configure_routes(app) configure_routes(app)
configure_middleware(app) configure_middleware(app)
@ -306,9 +303,11 @@ def client_offline_chat(
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig): def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup # Setup
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org" org_config = LocalOrgConfig.objects.filter(user=default_user).first()
input_filters = org_config.input_filter
new_org_file = Path(input_filters[0]).parent / "new_file.org"
new_org_file.touch() new_org_file.touch()
yield new_org_file yield new_org_file
@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path): def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
new_org_config = deepcopy(content_config.org) LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
new_org_config.input_files = [f"{new_org_file}"] return LocalOrgConfig.objects.filter(user=default_user).first()
new_org_config.input_filter = None
return new_org_config
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

tests/data/config.yml vendored

@ -9,17 +9,6 @@ content-type:
input-filter: input-filter:
- '*.org' - '*.org'
- ~/notes/*.org - ~/notes/*.org
plugins:
content_plugin_1:
compressed-jsonl: content_plugin_1.jsonl.gz
embeddings-file: content_plugin_1_embeddings.pt
input-files:
- content_plugin_1_new.jsonl.gz
content_plugin_2:
compressed-jsonl: content_plugin_2.jsonl.gz
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
enable-offline-chat: false enable-offline-chat: false
search-type: search-type:
asymmetric: asymmetric:


@ -48,14 +48,3 @@ def test_cli_config_from_file():
Path("~/first_from_config.org"), Path("~/first_from_config.org"),
Path("~/second_from_config.org"), Path("~/second_from_config.org"),
] ]
assert len(actual_args.config.content_type.plugins.keys()) == 2
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
Path("content_plugin_1_new.jsonl.gz")
]
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
"content_plugin_1.jsonl.gz"
)
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
"content_plugin_2_embeddings.pt"
)


@ -2,22 +2,21 @@
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from urllib.parse import quote from urllib.parse import quote
import pytest
# External Packages # External Packages
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import pytest from fastapi import FastAPI
# Internal Packages # Internal Packages
from app.main import app
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state from khoj.utils import state
from khoj.utils.state import search_models, content_index, config from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.word_filter import WordFilter from database.models import KhojUser
from khoj.search_filter.file_filter import FileFilter from database.adapters import EmbeddingsAdapters
# Test # Test
@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_search_with_valid_content_type(client): def test_search_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]: for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
# Act # Act
response = client.get(f"/api/search?q=random&t={content_type}") response = client.get(f"/api/search?q=random&t={content_type}")
# Assert # Assert
@ -75,7 +74,7 @@ def test_index_update(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_content_type(client): def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]: for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange # Arrange
files = get_sample_files_data() files = get_sample_files_data()
headers = {"x-api-key": "secret"} headers = {"x-api-key": "secret"}
@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.skip(reason="Flaky test on parallel test runs") @pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client): def test_get_configured_types_via_api(client, sample_org_data):
# Act # Act
response = client.get(f"/api/config/types") text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert # Assert
assert response.status_code == 200 assert list(enabled_types) == ["org"]
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_only_plugin_content_config(content_config): @pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
# Arrange # Arrange
config.content_type = ContentConfig() text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
config.content_type.plugins = content_config.plugins
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
# Act # Act
response = client.get(f"/api/config/types") response = client.get(f"/api/config/types")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["all", "plugin1"] assert response.json() == ["all", "org", "markdown", "image"]
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_plugin_content_config(content_config): @pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange # Arrange
config.content_type = content_config
config.content_type.plugins = None
state.SearchType = configure_search_types(config) state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
configure_routes(app) configure_routes(fastapi_app)
client = TestClient(app) client = TestClient(fastapi_app)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert "plugin1" not in response.json()
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_content_config():
# Arrange
config.content_type = ContentConfig()
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
# Act # Act
response = client.get(f"/api/config/types") response = client.get(f"/api/config/types")
@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == ["all"] assert response.json() == ["all"]
# Restore
state.config.content_type = original_config
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data): @pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
# Arrange # Arrange
search_models.text_search = text_search.initialize_model(search_config.asymmetric) text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?") user_query = quote("How to git install application?")
# Act # Act
@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters( def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
): ):
# Arrange # Arrange
filters = [WordFilter(), FileFilter()] text_search.setup(
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, OrgToJsonl,
sample_org_data, sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False, regenerate=False,
filters=filters,
) )
user_query = quote('+"Emacs" file:"*.org"') user_query = quote('+"Emacs" file:"*.org"')
@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter( @pytest.mark.django_db(transaction=True)
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data def test_notes_search_with_include_filter(client, sample_org_data):
):
# Arrange # Arrange
filters = [WordFilter()] text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
user_query = quote('How to git install application? +"Emacs"') user_query = quote('How to git install application? +"Emacs"')
# Act # Act
@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter( @pytest.mark.django_db(transaction=True)
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data def test_notes_search_with_exclude_filter(client, sample_org_data):
):
# Arrange # Arrange
filters = [WordFilter()] text_search.setup(
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, OrgToJsonl,
sample_org_data, sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False, regenerate=False,
filters=filters,
) )
user_query = quote('How to git install application? -"clone"') user_query = quote('How to git install application? -"clone"')
@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
assert "clone" not in search_result assert "clone" not in search_result
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 0
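This new test pins down the core guarantee of the commit: embeddings indexed for one account are invisible to queries made by another (here, the anonymous user). A hedged sketch of the same check expressed directly against the ORM, assuming the Embeddings model carries a user foreign key as the adapters in this commit imply:

```python
from typing import Optional

from database.models import Embeddings, KhojUser


def user_embedding_count(user: Optional[KhojUser]) -> int:
    # Rows indexed for one user should be invisible when scoped to another
    return Embeddings.objects.filter(user=user).count()

# After text_search.setup(OrgToJsonl, sample_org_data, user=default_user):
#   user_embedding_count(default_user) > 0
#   user_embedding_count(None) == 0  # the anonymous user sees nothing
```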
def get_sample_files_data(): def get_sample_files_data():
return { return {
"files": ("path/to/filename.org", "* practicing piano", "text/org"), "files": ("path/to/filename.org", "* practicing piano", "text/org"),


@ -1,53 +1,12 @@
# Standard Packages # Standard Packages
import re import re
from datetime import datetime from datetime import datetime
from math import inf
# External Packages # External Packages
import pytest import pytest
# Internal Packages # Internal Packages
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.utils.rawconfig import Entry
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
def test_date_filter():
entries = [
Entry(compiled="Entry with no date", raw="Entry with no date"),
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
]
q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == "head tail"
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1, 2}
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.") @pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@ -56,8 +15,8 @@ def test_extract_date_range():
datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp(),
] ]
-assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
+assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
-assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
+assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [ assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp(),

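extract_date_range now signals an open-ended bound with None instead of 0 or math.inf, which maps cleanly onto optional SQL range predicates. A small sketch of consuming the new return value; the query-filter field names in the trailing comment are purely illustrative:

```python
from datetime import datetime

from khoj.search_filter.date_filter import DateFilter

lower, upper = DateFilter().extract_date_range('head dt>="1984-01-01"')
assert lower == datetime(1984, 1, 1).timestamp()
assert upper is None  # open-ended upper bound, previously math.inf

# A None bound can simply be skipped when building a query, e.g.:
# if lower is not None: queryset = queryset.filter(date__gte=lower)
# if upper is not None: queryset = queryset.filter(date__lt=upper)
```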

@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
def test_no_file_filter(): def test_no_file_filter():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = "head tail" q_with_no_filter = "head tail"
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == False assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file(): def test_file_filter_with_non_existent_file():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail' q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {}
def test_single_file_filter(): def test_single_file_filter():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail' q_with_no_filter = 'head file:"file 1.org" tail'
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match(): def test_file_filter_with_partial_match():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail' q_with_no_filter = 'head file:"1.org" tail'
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match(): def test_file_filter_with_regex_match():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail' q_with_no_filter = 'head file:"*.org" tail'
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter(): def test_multiple_file_filter():
# Arrange # Arrange
file_filter = FileFilter() file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms(): def test_get_file_filter_terms():
@ -108,7 +84,7 @@ def test_get_file_filter_terms():
filter_terms = file_filter.get_filter_terms(q_with_filter_terms) filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
# Assert # Assert
-assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
+assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
def arrange_content(): def arrange_content():

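get_filter_terms now emits escaped regex fragments rather than the raw file:"..." tokens, since filtering is pushed down into SQL-side pattern matching instead of the removed in-memory apply() path. A self-contained sketch of what such terms look like and how they match; matches_any_file is a hypothetical helper, not part of the codebase:

```python
import re

# Regex terms in the shape returned by FileFilter.get_filter_terms above
terms = ["file 1\\.org", "/path/to/dir/.*\\.org"]


def matches_any_file(filename: str, regex_terms: list) -> bool:
    # Keep an entry if its file path matches any of the filter terms
    return any(re.search(term, filename) for term in regex_terms)


assert matches_any_file("/path/to/dir/notes.org", terms)
assert not matches_any_file("unrelated.md", terms)
```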

@ -1,78 +0,0 @@
# Internal Packages
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.utils.rawconfig import Entry
def test_process_entries_from_single_input_jsonl(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
"""
input_jsonl_file = create_file(tmp_path, input_jsonl)
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == input_jsonl
def test_process_entries_from_multiple_input_jsonls(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
def test_get_jsonl_files(tmp_path):
"Ensure JSONL files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
# Include via input-file field
file1 = create_file(tmp_path, filename="notes.jsonl")
# Not included by any filter
create_file(tmp_path, filename="not-included-jsonl.jsonl")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / "notes.jsonl"]
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
# Act
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="test.jsonl"):
jsonl_file = tmp_path / filename
jsonl_file.touch()
if entry:
jsonl_file.write_text(entry)
return jsonl_file


@ -4,7 +4,7 @@ import os
# Internal Packages # Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import is_none_or_empty from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files from khoj.utils.fs_syncer import get_org_files
@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words # Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens( TextEmbeddings.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4 OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
) )
) )
@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act # Act
# Split entry by max words and drop words larger than max word length # Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0] processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert # Assert
# "Heading" dropped from compiled version because its over the set max word limit # "Heading" dropped from compiled version because its over the set max word limit

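TextToJsonl has been renamed to TextEmbeddings, but split_entries_by_max_tokens keeps the interface these tests exercise. A rough sketch using only the parameters seen above (max_tokens, max_word_length); exact chunk counts depend on the internal tokenization:

```python
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry

entry = Entry(
    raw="Heading " + "word " * 300,
    compiled="Heading " + "word " * 300,
    heading="Heading",
    file="notes.org",
)

# Oversized entries get split into several chunks within the token budget
chunks = TextEmbeddings.split_entries_by_max_tokens([entry], max_tokens=4)
print(len(chunks))

# Words longer than max_word_length are dropped from the compiled text
trimmed = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
print("Heading" in trimmed.compiled)  # expected False, mirroring the test above
```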

@ -7,6 +7,7 @@ from pathlib import Path
from khoj.utils.fs_syncer import get_plaintext_files from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path): def test_plaintext_file(tmp_path):
@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
assert set(extracted_plaintext_files.keys()) == set(expected_files) assert set(extracted_plaintext_files.keys()) == set(expected_files)
def test_parse_html_plaintext_file(content_config): def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
"Ensure HTML files are parsed correctly" "Ensure HTML files are parsed correctly"
# Arrange # Arrange
# Setup input-files, input-filters # Setup input-files, input-filters
extracted_plaintext_files = get_plaintext_files(content_config.plaintext) config = LocalPlaintextConfig.objects.filter(user=default_user).first()
extracted_plaintext_files = get_plaintext_files(config=config)
# Act # Act
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files) maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)


@ -3,23 +3,30 @@ import logging
import locale import locale
from pathlib import Path from pathlib import Path
import os import os
import asyncio
# External Packages # External Packages
import pytest import pytest
# Internal Packages # Internal Packages
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels from khoj.utils.config import SearchModels
from khoj.utils.fs_syncer import get_org_files from khoj.utils.fs_syncer import get_org_files, collect_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
logger = logging.getLogger(__name__)
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig): @pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error(
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
):
# Arrange # Arrange
# Ensure file mentioned in org.input-files is missing # Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0]) single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path): @pytest.mark.django_db
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
# Arrange # Arrange
orgfile = tmp_path / "directory.org" / "file.org" orgfile = tmp_path / "directory.org" / "file.org"
orgfile.parent.mkdir() orgfile.parent.mkdir()
with open(orgfile, "w") as f: with open(orgfile, "w") as f:
f.write("* Heading\n- List item\n") f.write("* Heading\n- List item\n")
org_content_config = TextContentConfig(
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt" LocalOrgConfig.objects.create(
input_filter=[f"{tmp_path}/**/*"],
input_files=None,
user=default_user,
) )
org_files = collect_files(user=default_user)["org"]
# Act # Act
# should not raise IsADirectoryError and return orgfile # should not raise IsADirectoryError and return orgfile
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"} assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error( def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
): ):
# Arrange # Arrange
data = get_org_files(org_config_with_only_new_file) data = get_org_files(org_config_with_only_new_file)
# Act # Act
# Generate notes embeddings during asymmetric setup # Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found*"): with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True) text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
verify_embeddings(0, default_user)
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels): @pytest.mark.django_db
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
# Arrange # Arrange
data = get_org_files(content_config.org) org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act with caplog.at_level(logging.INFO):
# Regenerate notes embeddings during asymmetric setup text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert # Assert
assert len(notes_model.entries) == 10 assert "Deleting all embeddings for file type org" in caplog.records[1].message
assert len(notes_model.corpus_embeddings) == 10 assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog): @pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange # Arrange
caplog.set_level(logging.INFO, logger="khoj") org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
data = get_org_files(content_config.org)
# Act # Act
# Generate initial notes embeddings during asymmetric setup # Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True) with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text initial_logs = caplog.text
caplog.clear() # Clear logs caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated # Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False) with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
assert "Creating index from scratch." in initial_logs assert "Deleting all embeddings for file type org" in initial_logs
assert "Creating index from scratch." not in final_logs assert "Deleting all embeddings for file type org" not in final_logs
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.anyio @pytest.mark.anyio
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig): # @pytest.mark.asyncio
async def test_text_search(search_config: SearchConfig):
# Arrange # Arrange
data = get_org_files(content_config.org) default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
) )
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
index_heading_entries=False,
user=default_user,
)
data = get_org_files(org_config)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
text_search.setup,
OrgToJsonl,
data,
True,
True,
default_user,
)
query = "How to git install application?" query = "How to git install application?"
# Act # Act
hits, entries = await text_search.query( hits = await text_search.query(default_user, query)
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
# Assert # Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry # search results should contain "git clone" entry
search_result = results[0].entry search_result = results[0].entry
assert "git clone" in search_result assert "git clone" in search_result
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels): @pytest.mark.django_db
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange # Arrange
# Insert org-mode entry with size exceeding max token limit to new org file # Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256 max_tokens = 256
@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act # Act
# reload embeddings, entries, notes model after adding new org-mode file # reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup( with caplog.at_level(logging.INFO):
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
)
# Assert # Assert
# verify newly added org-mode entry is split by max tokens # verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 2 record = caplog.records[1]
assert len(initial_notes_model.corpus_embeddings) == 2 assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
 # ----------------------------------------------------------------------------------------------------
 # @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
+@pytest.mark.django_db
 def test_entry_chunking_by_max_tokens_not_full_corpus(
-    org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
 ):
     # Arrange
     # Insert org-mode entry with size exceeding max token limit to new org file
     data = {
         "readme.org": """
 * Khoj
 /Allow natural language search on user content like notes, images using transformer based models/
 All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
 ** Dependencies
 - Python3
 - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
 ** Install
 #+begin_src shell
 git clone https://github.com/khoj-ai/khoj && cd khoj
 conda env create -f environment.yml
 conda activate khoj
 #+end_src"""
     }
     text_search.setup(
         OrgToJsonl,
         data,
-        org_config_with_only_new_file,
-        search_models.text_search.bi_encoder,
         regenerate=False,
+        user=default_user,
     )

     max_tokens = 256
@@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
     # Act
     # reload embeddings, entries, notes model after adding new org-mode file
-    initial_notes_model = text_search.setup(
-        OrgToJsonl,
-        data,
-        org_config_with_only_new_file,
-        search_models.text_search.bi_encoder,
-        regenerate=False,
-        full_corpus=False,
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(
+            OrgToJsonl,
+            data,
+            regenerate=False,
+            full_corpus=False,
+            user=default_user,
+        )
+    record = caplog.records[1]

     # Assert
     # verify newly added org-mode entry is split by max tokens
-    assert len(initial_notes_model.entries) == 5
-    assert len(initial_notes_model.corpus_embeddings) == 5
+    assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 def test_regenerate_index_with_new_entry(
-    content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
+    content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
 ):
     # Arrange
-    data = get_org_files(content_config.org)
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
-    )
-    assert len(initial_notes_model.entries) == 10
-    assert len(initial_notes_model.corpus_embeddings) == 10
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message

     # append org-mode entry to first org input file in config
-    content_config.org.input_files = [f"{new_org_file}"]
+    org_config.input_files = [f"{new_org_file}"]
     with open(new_org_file, "w") as f:
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
-    data = get_org_files(content_config.org)
+    data = get_org_files(org_config)

     # Act
     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # Assert
-    assert len(regenerated_notes_model.entries) == 11
-    assert len(regenerated_notes_model.corpus_embeddings) == 11
-
-    # verify new entry appended to index, without disrupting order or content of existing entries
-    error_details = compare_index(initial_notes_model, regenerated_notes_model)
-    if error_details:
-        pytest.fail(error_details, False)
-
-    # Cleanup
-    # reset input_files in config to empty list
-    content_config.org.input_files = []
+    assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
+    verify_embeddings(11, default_user)


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
 def test_update_index_with_duplicate_entries_in_stable_order(
-    org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+    org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
 ):
     # Arrange
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
     # Act
     # load embeddings, entries, notes model after adding new org-mode file
-    initial_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     data = get_org_files(org_config_with_only_new_file)

     # update embeddings, entries, notes model after adding new org-mode file
-    updated_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert len(initial_index.entries) == len(updated_index.entries) == 1
-    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
-
-    # verify the same entry is added even when there are multiple duplicate entries
-    error_details = compare_index(initial_index, updated_index)
-    if error_details:
-        pytest.fail(error_details)
+    assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
+    verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
-def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
     # Arrange
     new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
     data = get_org_files(org_config_with_only_new_file)

     # load embeddings, entries, notes model after adding new org file with 2 entries
-    initial_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # update embeddings, entries, notes model after removing an entry from the org file
     with open(new_file_to_index, "w") as f:
@@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
     data = get_org_files(org_config_with_only_new_file)

     # Act
-    updated_index = text_search.setup(
-        OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
     # verify only 1 entry added even if there are multiple duplicate entries
-    assert len(initial_index.entries) == len(updated_index.entries) + 1
-    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
-
-    # verify the same entry is added even when there are multiple duplicate entries
-    error_details = compare_index(updated_index, initial_index)
-    if error_details:
-        pytest.fail(error_details)
+    assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
+    verify_embeddings(1, default_user)


 # ----------------------------------------------------------------------------------------------------
-def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+@pytest.mark.django_db
+def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
     # Arrange
-    data = get_org_files(content_config.org)
-    initial_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
-    )
+    org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+    data = get_org_files(org_config)
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)

     # append org-mode entry to first org input file in config
     with open(new_org_file, "w") as f:
         new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
         f.write(new_entry)
-    data = get_org_files(content_config.org)
+    data = get_org_files(org_config)

     # Act
     # update embeddings, entries with the newly added note
-    content_config.org.input_files = [f"{new_org_file}"]
-    final_notes_model = text_search.setup(
-        OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
-    )
+    with caplog.at_level(logging.INFO):
+        text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)

     # Assert
-    assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
-    assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
-
-    # verify new entry appended to index, without disrupting order or content of existing entries
-    error_details = compare_index(initial_notes_model, final_notes_model)
-    if error_details:
-        pytest.fail(error_details, False)
-
-    # Cleanup
-    # reset input_files in config to empty list
-    content_config.org.input_files = []
+    assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+    assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
+    verify_embeddings(11, default_user)


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
-def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
+def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
+    # Arrange
+    github_config = GithubConfig.objects.filter(user=default_user).first()
+
     # Act
     # Regenerate github embeddings to test asymmetric setup without caching
-    github_model = text_search.setup(
-        GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
+    text_search.setup(
+        GithubToJsonl,
+        {},
+        regenerate=True,
+        user=default_user,
+        config=github_config,
     )

     # Assert
-    assert len(github_model.entries) > 1
+    embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
+    assert embeddings > 1


-def compare_index(initial_notes_model, final_notes_model):
-    mismatched_entries, mismatched_embeddings = [], []
-    for index in range(len(initial_notes_model.entries)):
-        if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
-            mismatched_entries.append(index)
-
-    # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
-    for index in range(len(initial_notes_model.corpus_embeddings)):
-        if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
-            mismatched_embeddings.append(index)
-
-    error_details = ""
-    if mismatched_entries:
-        mismatched_entries_str = ",".join(map(str, mismatched_entries))
-        error_details += f"Entries at {mismatched_entries_str} not equal\n"
-    if mismatched_embeddings:
-        mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
-        error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
-    return error_details
+def verify_embeddings(expected_count, user):
+    embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
+    assert embeddings == expected_count


@@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
 from khoj.utils.rawconfig import Entry


-def test_no_word_filter():
-    # Arrange
-    word_filter = WordFilter()
-    entries = arrange_content()
-    q_with_no_filter = "head tail"
-
-    # Act
-    can_filter = word_filter.can_filter(q_with_no_filter)
-    ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
-
-    # Assert
-    assert can_filter == False
-    assert ret_query == "head tail"
-    assert entry_indices == {0, 1, 2, 3}
-
-
 def test_word_exclude_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     q_with_exclude_filter = 'head -"exclude_word" tail'

     # Act
     can_filter = word_filter.can_filter(q_with_exclude_filter)
-    ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {0, 2}


 def test_word_include_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     query_with_include_filter = 'head +"include_word" tail'

     # Act
     can_filter = word_filter.can_filter(query_with_include_filter)
-    ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {2, 3}


 def test_word_include_and_exclude_filter():
     # Arrange
     word_filter = WordFilter()
-    entries = arrange_content()
     query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'

     # Act
     can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
-    ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)

     # Assert
     assert can_filter == True
-    assert ret_query == "head tail"
-    assert entry_indices == {2}
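The deleted apply() assertions reflect the change described in the commit message: word filters are no longer applied to in-memory entry indices but as SQL queries against the indexed entries. As a rough illustration only, such a filter could be expressed with the Django ORM along these lines; the `raw` field name and the helper itself are assumptions for the sketch, not the exact khoj implementation:

from django.db.models import QuerySet

def apply_word_filters(entries: QuerySet, include_words: list[str], exclude_words: list[str]) -> QuerySet:
    # Keep only entries containing every included word (case-insensitive)
    for word in include_words:
        entries = entries.filter(raw__icontains=word)
    # Drop entries containing any excluded word
    for word in exclude_words:
        entries = entries.exclude(raw__icontains=word)
    return entries

# e.g. apply_word_filters(Embeddings.objects.filter(user=user), ["include_word"], ["exclude_word"])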
 def test_get_word_filter_terms():