[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)

- Partition configuration for indexing local data based on user accounts
- Store indexed data in an underlying Postgres database using the `pgvector` extension (see the query sketch after this list)
- Add migrations for all relevant user data and embeddings generation; lookup time has had very little performance optimization so far
- Apply filters using SQL queries
- Start removing many server-level configuration settings
- Configure GitHub test actions to run on every PR. Update the test action to run in a containerized environment with a database.
- Update the Docker image and docker-compose.yml to work with the new application design
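
For orientation, here is a minimal sketch of the per-user lookup this change enables, modeled on the `EmbeddingsAdapters.search_with_embeddings` adapter further down in this diff. The `Embeddings` model and `CosineDistance` annotation are as introduced in this commit; the standalone helper function and its name are illustrative only.

```python
# Illustrative sketch (not the adapter itself): query one user's embeddings by
# cosine distance through the Django ORM pgvector integration added in this change.
from pgvector.django import CosineDistance

from database.models import Embeddings, KhojUser


def search_user_embeddings(user: KhojUser, query_embedding, max_results: int = 10):
    # Rows are partitioned per user, so results never cross user accounts.
    return (
        Embeddings.objects.filter(user=user)
        .annotate(distance=CosineDistance("embeddings", query_embedding))
        .order_by("distance")[:max_results]
    )
```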
Author: sabaimran · 2023-10-26 09:42:29 -07:00 · committed via GitHub
parent 963cd165eb · commit 216acf545f
60 changed files with 1827 additions and 1792 deletions

.github/workflows/pre-commit.yml (new file)

@ -0,0 +1,48 @@
name: pre-commit
on:
pull_request:
paths:
- src/**
- tests/**
- config/**
- pyproject.toml
- .pre-commit-config.yml
- .github/workflows/test.yml
push:
branches:
- master
paths:
- src/khoj/**
- tests/**
- config/**
- pyproject.toml
- .pre-commit-config.yml
- .github/workflows/test.yml
jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: ⏬️ Install Dependencies
run: |
sudo apt update && sudo apt install -y libegl1
python -m pip install --upgrade pip
- name: ⬇️ Install Application
run: pip install --upgrade .[dev]
- name: 🌡️ Validate Application
run: pre-commit run --hook-stage manual --all

.github/workflows/test.yml

@ -2,10 +2,8 @@ name: test
on:
pull_request:
branches:
- 'master'
paths:
- src/khoj/**
- src/**
- tests/**
- config/**
- pyproject.toml
@ -13,7 +11,7 @@ on:
- .github/workflows/test.yml
push:
branches:
- 'master'
- master
paths:
- src/khoj/**
- tests/**
@ -26,6 +24,7 @@ jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
container: ubuntu:jammy
strategy:
fail-fast: false
matrix:
@ -33,6 +32,17 @@ jobs:
- '3.9'
- '3.10'
- '3.11'
services:
postgres:
image: ankane/pgvector
env:
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres
ports:
- 5432:5432
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- uses: actions/checkout@v3
with:
@ -43,17 +53,37 @@ jobs:
with:
python-version: ${{ matrix.python_version }}
- name: ⏬️ Install Dependencies
- name: Install Git
run: |
sudo apt update && sudo apt install -y libegl1
apt update && apt install -y git
- name: ⏬️ Install Dependencies
env:
DEBIAN_FRONTEND: noninteractive
run: |
apt update && apt install -y libegl1 sqlite3 libsqlite3-dev libsqlite3-0
- name: ⬇️ Install Postgres
env:
DEBIAN_FRONTEND: noninteractive
run : |
apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
- name: ⬇️ Install pip
run: |
apt install -y python3-pip
python -m ensurepip --upgrade
python -m pip install --upgrade pip
- name: ⬇️ Install Application
run: pip install --upgrade .[dev]
- name: 🌡️ Validate Application
run: pre-commit run --hook-stage manual --all
run: sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && pip install --upgrade .[dev]
- name: 🧪 Test Application
env:
POSTGRES_HOST: postgres
POSTGRES_PORT: 5432
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
run: pytest
timeout-minutes: 10

Dockerfile

@ -8,13 +8,20 @@ RUN apt update -y && apt -y install python3-pip git
WORKDIR /app
# Install Application
COPY . .
COPY pyproject.toml .
COPY README.md .
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
pip install --no-cache-dir .
# Copy Source Code
COPY . .
# Set the PYTHONPATH environment variable in order for it to find the Django app.
ENV PYTHONPATH=/app/src:$PYTHONPATH
# Run the Application
# There are more arguments required for the application to run,
# but these should be passed in through the docker-compose.yml file.
ARG PORT
EXPOSE ${PORT}
ENTRYPOINT ["khoj"]
ENTRYPOINT ["python3", "src/khoj/main.py"]

docker-compose.yml

@ -1,7 +1,21 @@
version: "3.9"
services:
database:
image: ankane/pgvector
ports:
- "5432:5432"
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
volumes:
- khoj_db:/var/lib/postgresql/data/
server:
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
image: ghcr.io/khoj-ai/khoj:latest
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the official image.
# build:
# context: .
ports:
# If changing the local port (left hand side), no other changes required.
# If changing the remote port (right hand side),
@ -26,8 +40,15 @@ services:
- ./tests/data/models/:/root/.khoj/search/
- khoj_config:/root/.khoj/
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
environment:
- POSTGRES_DB=postgres
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
- POSTGRES_HOST=database
- POSTGRES_PORT=5432
command: --host="0.0.0.0" --port=42110 -vv
volumes:
khoj_config:
khoj_db:

pyproject.toml

@ -59,13 +59,15 @@ dependencies = [
"requests >= 2.26.0",
"bs4 >= 0.0.1",
"anyio == 3.7.1",
"pymupdf >= 1.23.3",
"pymupdf >= 1.23.5",
"django == 4.2.5",
"authlib == 1.2.1",
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
"itsdangerous == 2.1.2",
"httpx == 0.25.0",
"pgvector == 0.2.3",
"psycopg2-binary == 2.9.9",
]
dynamic = ["version"]
@ -91,6 +93,8 @@ dev = [
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
"pytest-django == 4.5.2",
"pytest-asyncio == 0.21.1",
]
[tool.hatch.version]

4
pytest.ini Normal file
View file

@ -0,0 +1,4 @@
[pytest]
DJANGO_SETTINGS_MODULE = app.settings
pythonpath = . src
testpaths = tests

src/app/README.md (new file)

@ -0,0 +1,60 @@
# Django App
Khoj uses Django as the backend framework primarily for its powerful ORM and the admin interface. The Django app is located in the `src/app` directory. We have one installed app, under the `/database/` directory. This app is responsible for all the database related operations and holds all of our models. You can find the extensive Django documentation [here](https://docs.djangoproject.com/en/4.2/) 🌈.
## Setup (Docker)
### Prerequisites
1. Ensure you have [Docker](https://docs.docker.com/get-docker/) installed.
2. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
### Run
Using the `docker-compose.yml` file in the root directory, you can run the Khoj app using the following command:
```bash
docker-compose up
```
## Setup (Local)
### Install dependencies
```bash
pip install -e '.[dev]'
```
### Setup the database
1. Ensure you have Postgres installed. For macOS, you can use [Postgres.app](https://postgresapp.com/).
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
```bash
cd /tmp
git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
cd pgvector
make
make install # may need sudo
```
3. Create a database for Khoj to use (a minimal sketch for this step is shown below).
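A minimal sketch for this step, assuming a local Postgres superuser `postgres` with password `postgres` (the defaults used elsewhere in this change) and the default database name `khoj` from `app/settings.py`; adjust for your local setup:
```python
# Hypothetical one-off helper to create the Khoj database locally.
# Assumes Postgres is listening on localhost:5432 with user/password "postgres".
import psycopg2

conn = psycopg2.connect(
    dbname="postgres", user="postgres", password="postgres", host="localhost", port=5432
)
conn.autocommit = True  # CREATE DATABASE cannot run inside a transaction block
with conn.cursor() as cur:
    cur.execute("CREATE DATABASE khoj")
conn.close()
```
The `vector` extension itself is created by the `0003_vector_extension` migration, so no manual `CREATE EXTENSION` should be needed.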
### Make migrations
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
```bash
python3 src/manage.py makemigrations
```
### Run migrations
This command will run any pending migrations in your application.
```bash
python3 src/manage.py migrate
```
### Run the server
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
```bash
python3 src/khoj/main.py
```
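For reference, the scaffolding the server performs at startup looks roughly like this, condensed from the changes to `src/khoj/main.py` in this commit (a sketch, not a drop-in replacement):
```python
# Condensed sketch of the Django bootstrap done in src/khoj/main.py at startup.
import os

import django
from django.core.management import call_command

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
django.setup()

# Apply pending migrations and collect static assets before serving.
call_command("migrate", "--noinput")
call_command("collectstatic", "--noinput")
```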

src/app/settings.py

@ -77,8 +77,12 @@ WSGI_APPLICATION = "app.wsgi.application"
DATABASES = {
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": BASE_DIR / "db.sqlite3",
"ENGINE": "django.db.backends.postgresql",
"HOST": os.getenv("POSTGRES_HOST", "localhost"),
"PORT": os.getenv("POSTGRES_PORT", "5432"),
"USER": os.getenv("POSTGRES_USER", "postgres"),
"NAME": os.getenv("POSTGRES_DB", "khoj"),
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
}
}


@ -1,15 +1,30 @@
from typing import Type, TypeVar
from typing import Type, TypeVar, List
import uuid
from datetime import date
from django.db import models
from django.contrib.sessions.backends.db import SessionStore
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
from fastapi import HTTPException
from database.models import KhojUser, GoogleUser, NotionConfig
from database.models import (
KhojUser,
GoogleUser,
NotionConfig,
GithubConfig,
Embeddings,
GithubRepoConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
ModelType = TypeVar("ModelType", bound=models.Model)
@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
)
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
if not user:
raise HTTPException(status_code=401, detail="Invalid user")
return user
def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all()
def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if not config:
return None
return config
def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first()
if not config:
return None
return config
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
await object.objects.filter(user=user).adelete()
await object.objects.acreate(
input_files=deduped_files,
input_filter=deduped_filters,
index_heading_entries=updated_config.index_heading_entries,
user=user,
)
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
config = await GithubConfig.objects.filter(user=user).afirst()
if not config:
config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
else:
config.pat_token = pat_token
await config.asave()
await config.githubrepoconfig.all().adelete()
for repo in repos:
await GithubRepoConfig.objects.acreate(
name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config
)
return config
class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
date_filter = DateFilter()
@staticmethod
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
def delete_embedding_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
def delete_all_embeddings(user: KhojUser, file_type: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
return embeddings.filter(
embeddingsdates__date__gte=start_date,
embeddingsdates__date__lte=end_date,
)
@staticmethod
async def user_has_embeddings(user: KhojUser):
return await Embeddings.objects.filter(user=user).aexists()
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Embeddings.objects.filter(user=user)
for term in explicit_word_terms:
if term.startswith("+"):
q_filter_terms &= Q(raw__icontains=term[1:])
elif term.startswith("-"):
q_filter_terms &= ~Q(raw__icontains=term[1:])
q_file_filter_terms = Q()
if len(file_filters) > 0:
for term in file_filters:
q_file_filter_terms |= Q(file_path__regex=term)
q_filter_terms &= q_file_filter_terms
if len(date_filters) > 0:
min_date, max_date = date_filters
if min_date is not None:
# Convert the min_date timestamp to yyyy-mm-dd format
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
if max_date is not None:
# Convert the max_date timestamp to yyyy-mm-dd format
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
q_filter_terms,
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
return relevant_embeddings
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
):
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
relevant_embeddings = relevant_embeddings.order_by("distance")
return relevant_embeddings[:max_results]
@staticmethod
def get_unique_file_types(user: KhojUser):
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()


@ -1,79 +0,0 @@
# Generated by Django 4.2.5 on 2023-09-27 17:52
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [
migrations.CreateModel(
name="Configuration",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
],
),
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("pat_token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
),
],
),
migrations.AddField(
model_name="configuration",
name="user",
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
]

src/database/migrations/0003_vector_extension.py (new file)

@ -0,0 +1,10 @@
from django.db import migrations
from pgvector.django import VectorExtension
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [VectorExtension()]

src/database/migrations/0004_conversationprocessorconfig_githubconfig_and_more.py (new file)

@ -0,0 +1,193 @@
# Generated by Django 4.2.5 on 2023-10-11 22:24
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import pgvector.django
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0003_vector_extension"),
]
operations = [
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("pat_token", models.CharField(max_length=200)),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("token", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPlaintextConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPdfConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalOrgConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalMarkdownConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="githubrepoconfig",
to="database.githubconfig",
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="githubconfig",
name="user",
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
migrations.CreateModel(
name="Embeddings",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("embeddings", pgvector.django.VectorField(dimensions=384)),
("raw", models.TextField()),
("compiled", models.TextField()),
("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)),
(
"file_type",
models.CharField(
choices=[
("image", "Image"),
("pdf", "Pdf"),
("plaintext", "Plaintext"),
("markdown", "Markdown"),
("org", "Org"),
("notion", "Notion"),
("github", "Github"),
("conversation", "Conversation"),
],
default="plaintext",
max_length=30,
),
),
("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)),
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
("url", models.URLField(blank=True, default=None, max_length=400, null=True)),
("hashed_value", models.CharField(max_length=100)),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

src/database/migrations/0005_embeddings_corpus_id.py (new file)

@ -0,0 +1,18 @@
# Generated by Django 4.2.5 on 2023-10-13 02:39
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0004_conversationprocessorconfig_githubconfig_and_more"),
]
operations = [
migrations.AddField(
model_name="embeddings",
name="corpus_id",
field=models.UUIDField(default=uuid.uuid4, editable=False),
),
]

src/database/migrations/0006_embeddingsdates.py (new file)

@ -0,0 +1,33 @@
# Generated by Django 4.2.5 on 2023-10-13 19:28
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0005_embeddings_corpus_id"),
]
operations = [
migrations.CreateModel(
name="EmbeddingsDates",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("date", models.DateField()),
(
"embeddings",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="embeddings_dates",
to="database.embeddings",
),
),
],
options={
"indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")],
},
),
]

src/database/models.py

@ -2,11 +2,25 @@ import uuid
from django.db import models
from django.contrib.auth.models import AbstractUser
from pgvector.django import VectorField
class BaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
abstract = True
class KhojUser(AbstractUser):
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
def save(self, *args, **kwargs):
if not self.uuid:
self.uuid = uuid.uuid4()
super().save(*args, **kwargs)
class GoogleUser(models.Model):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
@ -23,31 +37,85 @@ class GoogleUser(models.Model):
return self.name
class Configuration(models.Model):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
class NotionConfig(models.Model):
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
compressed_jsonl = models.CharField(max_length=300)
embeddings_file = models.CharField(max_length=300)
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubConfig(models.Model):
class GithubConfig(BaseModel):
pat_token = models.CharField(max_length=200)
compressed_jsonl = models.CharField(max_length=300)
embeddings_file = models.CharField(max_length=300)
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubRepoConfig(models.Model):
class GithubRepoConfig(BaseModel):
name = models.CharField(max_length=200)
owner = models.CharField(max_length=200)
branch = models.CharField(max_length=200)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
class ConversationProcessorConfig(models.Model):
class LocalOrgConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalMarkdownConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPdfConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPlaintextConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class ConversationProcessorConfig(BaseModel):
conversation = models.JSONField()
enable_offline_chat = models.BooleanField(default=False)
class Embeddings(BaseModel):
class EmbeddingsType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
PLAINTEXT = "plaintext"
MARKDOWN = "markdown"
ORG = "org"
NOTION = "notion"
GITHUB = "github"
CONVERSATION = "conversation"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=384)
raw = models.TextField()
compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
class EmbeddingsDates(BaseModel):
date = models.DateField()
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
class Meta:
indexes = [
models.Index(fields=["date"]),
]

src/database/tests.py (new file)

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

src/khoj/configure.py

@ -30,6 +30,8 @@ from khoj.utils.helpers import resolve_absolute_path, merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser
from database.adapters import get_all_users
logger = logging.getLogger(__name__)
@ -48,14 +50,28 @@ class UserAuthenticationBackend(AuthenticationBackend):
from database.models import KhojUser
self.khojuser_manager = KhojUser.objects
self._initialize_default_user()
super().__init__()
def _initialize_default_user(self):
if not self.khojuser_manager.filter(username="default").exists():
self.khojuser_manager.create_user(
username="default",
email="default@example.com",
password="default",
)
async def authenticate(self, request):
current_user = request.session.get("user")
if current_user and current_user.get("email"):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
elif not state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser()
@ -78,7 +94,11 @@ def initialize_server(config: Optional[FullConfig]):
def configure_server(
config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
config: FullConfig,
regenerate: bool = False,
search_type: Optional[SearchType] = None,
init=False,
user: KhojUser = None,
):
# Update Config
state.config = config
@ -95,7 +115,7 @@ def configure_server(
state.config_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init)
initialize_content(regenerate, search_type, init, user)
except Exception as e:
logger.error(f"🚨 Failed to configure search models", exc_info=True)
raise e
@ -103,7 +123,7 @@ def configure_server(
state.config_lock.release()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
# Initialize Content from Config
if state.search_models:
try:
@ -112,7 +132,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
else:
logger.info("📬 Updating content index...")
all_files = collect_files(state.config.content_type)
all_files = collect_files(user=user)
state.content_index = configure_content(
state.content_index,
state.config.content_type,
@ -120,6 +140,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
state.search_models,
regenerate,
search_type,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to index content", exc_info=True)
@ -152,9 +173,14 @@ if not state.demo:
def update_search_index():
try:
logger.info("📬 Updating content index via Scheduler")
all_files = collect_files(state.config.content_type)
for user in get_all_users():
all_files = collect_files(user=user)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
state.content_index, state.config.content_type, all_files, state.search_models, user=user
)
all_files = collect_files(user=None)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=None
)
logger.info("📪 Content index updated via Scheduler")
except Exception as e:
@ -164,13 +190,9 @@ if not state.demo:
def configure_search_types(config: FullConfig):
# Extract core search types
core_search_types = {e.name: e.value for e in SearchType}
# Extract configured plugin search types
plugin_search_types = {}
if config.content_type and config.content_type.plugins:
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
# Dynamically generate search type enum by merging core search types with configured plugin search types
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
return Enum("SearchType", merge_dicts(core_search_types, {}))
def configure_processor(


@ -10,13 +10,11 @@
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
<h3 class="card-title">
Github
{% if current_config.content_type.github %}
{% if current_model_state.github == False %}
<img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
{% else %}
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% endif %}
{% endif %}
</h3>
</div>
<div class="card-description-row">
@ -24,7 +22,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/github">
{% if current_config.content_type.github %}
{% if current_model_state.github %}
Update
{% else %}
Setup
@ -32,7 +30,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.github %}
{% if current_model_state.github %}
<div id="clear-github" class="card-action-row">
<button class="card-button" onclick="clearContentType('github')">
Disable
@ -45,13 +43,11 @@
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
<h3 class="card-title">
Notion
{% if current_config.content_type.notion %}
{% if current_model_state.notion == False %}
<img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
{% else %}
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% endif %}
{% endif %}
</h3>
</div>
<div class="card-description-row">
@ -59,7 +55,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/notion">
{% if current_config.content_type.content %}
{% if current_model_state.content %}
Update
{% else %}
Setup
@ -67,7 +63,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.notion %}
{% if current_model_state.notion %}
<div id="clear-notion" class="card-action-row">
<button class="card-button" onclick="clearContentType('notion')">
Disable
@ -80,7 +76,7 @@
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
<h3 class="card-title">
Markdown
{% if current_config.content_type.markdown %}
{% if current_model_state.markdown %}
{% if current_model_state.markdown == False%}
<img id="misconfigured-icon-markdown" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
{% else %}
@ -94,7 +90,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/markdown">
{% if current_config.content_type.markdown %}
{% if current_model_state.markdown %}
Update
{% else %}
Setup
@ -102,7 +98,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.markdown %}
{% if current_model_state.markdown %}
<div id="clear-markdown" class="card-action-row">
<button class="card-button" onclick="clearContentType('markdown')">
Disable
@ -115,7 +111,7 @@
<img class="card-icon" src="/static/assets/icons/org.svg" alt="org">
<h3 class="card-title">
Org
{% if current_config.content_type.org %}
{% if current_model_state.org %}
{% if current_model_state.org == False %}
<img id="misconfigured-icon-org" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
{% else %}
@ -129,7 +125,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/org">
{% if current_config.content_type.org %}
{% if current_model_state.org %}
Update
{% else %}
Setup
@ -137,7 +133,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.org %}
{% if current_model_state.org %}
<div id="clear-org" class="card-action-row">
<button class="card-button" onclick="clearContentType('org')">
Disable
@ -150,7 +146,7 @@
<img class="card-icon" src="/static/assets/icons/pdf.svg" alt="PDF">
<h3 class="card-title">
PDF
{% if current_config.content_type.pdf %}
{% if current_model_state.pdf %}
{% if current_model_state.pdf == False %}
<img id="misconfigured-icon-pdf" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
{% else %}
@ -164,7 +160,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/pdf">
{% if current_config.content_type.pdf %}
{% if current_model_state.pdf %}
Update
{% else %}
Setup
@ -172,7 +168,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.pdf %}
{% if current_model_state.pdf %}
<div id="clear-pdf" class="card-action-row">
<button class="card-button" onclick="clearContentType('pdf')">
Disable
@ -185,7 +181,7 @@
<img class="card-icon" src="/static/assets/icons/plaintext.svg" alt="Plaintext">
<h3 class="card-title">
Plaintext
{% if current_config.content_type.plaintext %}
{% if current_model_state.plaintext %}
{% if current_model_state.plaintext == False %}
<img id="misconfigured-icon-plaintext" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
{% else %}
@ -199,7 +195,7 @@
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/plaintext">
{% if current_config.content_type.plaintext %}
{% if current_model_state.plaintext %}
Update
{% else %}
Setup
@ -207,7 +203,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.content_type.plaintext %}
{% if current_model_state.plaintext %}
<div id="clear-plaintext" class="card-action-row">
<button class="card-button" onclick="clearContentType('plaintext')">
Disable


@ -38,24 +38,6 @@
{% endfor %}
</div>
<button type="button" id="add-repository-button">Add Repository</button>
<table style="display: none;" >
<tr>
<td>
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
</td>
<td>
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
</td>
</tr>
<tr>
<td>
<label for="embeddings-file">Embeddings File (Output)</label>
</td>
<td>
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
</td>
</tr>
</table>
<div class="section">
<div id="success" style="display: none;"></div>
<button id="submit" type="submit">Save</button>
@ -107,8 +89,6 @@
submit.addEventListener("click", function(event) {
event.preventDefault();
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
const embeddings_file = document.getElementById("embeddings-file").value;
const pat_token = document.getElementById("pat-token").value;
if (pat_token == "") {
@ -154,8 +134,6 @@
body: JSON.stringify({
"pat_token": pat_token,
"repos": repos,
"compressed_jsonl": compressed_jsonl,
"embeddings_file": embeddings_file,
})
})
.then(response => response.json())


@ -43,33 +43,6 @@
</td>
</tr>
</table>
<table style="display: none;" >
<tr>
<td>
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
</td>
<td>
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
</td>
</tr>
<tr>
<td>
<label for="embeddings-file">Embeddings File (Output)</label>
</td>
<td>
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
</td>
</tr>
<tr>
<td>
<label for="index-heading-entries">Index Heading Entries</label>
</td>
<td>
<input type="text" id="index-heading-entries" name="index-heading-entries" value="{{ current_config['index_heading_entries'] }}">
</td>
</tr>
</table>
<div class="section">
<div id="success" style="display: none;" ></div>
<button id="submit" type="submit">Save</button>
@ -155,9 +128,8 @@
inputFilter = null;
}
var compressed_jsonl = document.getElementById("compressed-jsonl").value;
var embeddings_file = document.getElementById("embeddings-file").value;
var index_heading_entries = document.getElementById("index-heading-entries").value;
// var index_heading_entries = document.getElementById("index-heading-entries").value;
var index_heading_entries = true;
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content_type/{{ content_type }}', {
@ -169,8 +141,6 @@
body: JSON.stringify({
"input_files": inputFiles,
"input_filter": inputFilter,
"compressed_jsonl": compressed_jsonl,
"embeddings_file": embeddings_file,
"index_heading_entries": index_heading_entries
})
})


@ -20,24 +20,6 @@
</td>
</tr>
</table>
<table style="display: none;" >
<tr>
<td>
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
</td>
<td>
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
</td>
</tr>
<tr>
<td>
<label for="embeddings-file">Embeddings File (Output)</label>
</td>
<td>
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
</td>
</tr>
</table>
<div class="section">
<div id="success" style="display: none;"></div>
<button id="submit" type="submit">Save</button>
@ -51,8 +33,6 @@
submit.addEventListener("click", function(event) {
event.preventDefault();
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
const embeddings_file = document.getElementById("embeddings-file").value;
const token = document.getElementById("token").value;
if (token == "") {
@ -70,8 +50,6 @@
},
body: JSON.stringify({
"token": token,
"compressed_jsonl": compressed_jsonl,
"embeddings_file": embeddings_file,
})
})
.then(response => response.json())


@ -172,7 +172,7 @@
url = createRequestUrl(query, type, results_count || 5, rerank);
fetch(url, {
headers: {
"X-CSRFToken": csrfToken
"Content-Type": "application/json"
}
})
.then(response => response.json())
@ -199,8 +199,8 @@
fetch("/api/config/types")
.then(response => response.json())
.then(enabled_types => {
// Show warning if no content types are enabled
if (enabled_types.detail) {
// Show warning if no content types are enabled, or just one ("all")
if (enabled_types[0] === "all" && enabled_types.length === 1) {
document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>";
document.getElementById("query").setAttribute("disabled", "disabled");
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");

src/khoj/main.py

@ -24,10 +24,15 @@ from rich.logging import RichHandler
from django.core.asgi import get_asgi_application
from django.core.management import call_command
# Internal Packages
from khoj.configure import configure_routes, initialize_server, configure_middleware
from khoj.utils import state
from khoj.utils.cli import cli
# Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
django.setup()
# Initialize Django Database
call_command("migrate", "--noinput")
# Initialize Django Static Files
call_command("collectstatic", "--noinput")
# Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
@ -54,6 +59,11 @@ app.add_middleware(
# Set Locale
locale.setlocale(locale.LC_ALL, "")
# Internal Packages. We do this after setting up Django so that Django features are accessible to the app.
from khoj.configure import configure_routes, initialize_server, configure_middleware
from khoj.utils import state
from khoj.utils.cli import cli
# Setup Logger
rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
@ -95,6 +105,8 @@ def run():
# Mount Django and Static Files
app.mount("/django", django_app, name="django")
if not os.path.exists("static"):
os.mkdir("static")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Configure Middleware
@ -111,6 +123,7 @@ def set_state(args):
state.host = args.host
state.port = args.port
state.demo = args.demo
state.anonymous_mode = args.anonymous_mode
state.khoj_version = version("khoj-assistant")


@ -0,0 +1,57 @@
from typing import List
import torch
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
encode_kwargs = {"normalize_embeddings": True}
# encode_kwargs = {}
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
model_kwargs = {"device": device}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
def embed_query(self, query):
return self.embeddings_model.embed_query(query)
def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
if torch.cuda.is_available():
# Use CUDA GPU
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
# Use Apple M1 Metal Acceleration
device = torch.device("mps")
else:
device = torch.device("cpu")
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross__inp)
return cross_scores

src/khoj/processor/github/github_to_jsonl.py

@ -2,7 +2,7 @@
import logging
import time
from datetime import datetime
from typing import Dict, List, Union
from typing import Dict, List, Union, Tuple
# External Packages
import requests
@ -12,18 +12,31 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.jsonl import compress_jsonl_data
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, GithubConfig, KhojUser
logger = logging.getLogger(__name__)
class GithubToJsonl(TextToJsonl):
def __init__(self, config: GithubContentConfig):
class GithubToJsonl(TextEmbeddings):
def __init__(self, config: GithubConfig):
super().__init__(config)
self.config = config
raw_repos = config.githubrepoconfig.all()
repos = []
for repo in raw_repos:
repos.append(
GithubRepoConfig(
name=repo.name,
owner=repo.owner,
branch=repo.branch,
)
)
self.config = GithubContentConfig(
pat_token=config.pat_token,
repos=repos,
)
self.session = requests.Session()
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl):
else:
return
def process(self, previous_entries=[], files=None, full_corpus=True):
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")
@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl):
for repo in self.config.repos:
current_entries += self.process_repo(repo)
return self.update_entries_with_ids(current_entries, previous_entries)
return self.update_entries_with_ids(current_entries, user=user)
def process_repo(self, repo: GithubRepoConfig):
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl):
current_entries += issue_entries
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256)
return current_entries
def update_entries_with_ids(self, current_entries, previous_entries):
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
)
with timer("Write github entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings
def get_files(self, repo_url: str, repo: GithubRepoConfig):
# Get the contents of the repository


@ -1,91 +0,0 @@
# Standard Packages
import glob
import logging
from pathlib import Path
from typing import List
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.helpers import get_absolute_path, timer
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class JsonlToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[], files: dict[str, str] = {}, full_corpus: bool = True):
# Extract required fields from config
input_jsonl_files, input_jsonl_filter, output_file = (
self.config.input_files,
self.config.input_filter,
self.config.compressed_jsonl,
)
# Get Jsonl Input Files to Process
all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter)
# Extract Entries from specified jsonl files
with timer("Parse entries from jsonl files", logger):
input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files)
current_entries = list(map(Entry.from_dict, input_jsons))
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
)
with timer("Write entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
@staticmethod
def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None):
"Get all jsonl files to process"
absolute_jsonl_files, filtered_jsonl_files = set(), set()
if jsonl_files:
absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files}
if jsonl_file_filters:
filtered_jsonl_files = {
filtered_file
for jsonl_file_filter in jsonl_file_filters
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
}
all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files)
files_with_non_jsonl_extensions = {
jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl")
}
if any(files_with_non_jsonl_extensions):
print(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}")
logger.debug(f"Processing files: {all_jsonl_files}")
return all_jsonl_files
@staticmethod
def extract_jsonl_entries(jsonl_files):
"Extract entries from specified jsonl files"
entries = []
for jsonl_file in jsonl_files:
entries.extend(load_jsonl(Path(jsonl_file)))
return entries
@staticmethod
def convert_entries_to_jsonl(entries: List[Entry]):
"Convert each entry to JSON and collate as JSONL"
return "".join([f"{entry.to_json()}\n" for entry in entries])

src/khoj/processor/markdown/markdown_to_jsonl.py

@ -3,29 +3,28 @@ import logging
import re
import urllib3
from pathlib import Path
from typing import List
from typing import Tuple, List
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
class MarkdownToJsonl(TextToJsonl):
def __init__(self, config: TextContentConfig):
super().__init__(config)
self.config = config
class MarkdownToJsonl(TextEmbeddings):
def __init__(self):
super().__init__()
# Define Functions
def process(self, previous_entries=[], files=None, full_corpus: bool = True):
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
output_file = self.config.compressed_jsonl
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.MARKDOWN,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
)
with timer("Write markdown entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_markdown_entries(markdown_files):
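Across these processor diffs the process() contract changes: instead of taking previous_entries and returning (id, Entry) tuples, each processor now takes the raw files plus the owning KhojUser and returns (num_new_embeddings, num_deleted_embeddings) after writing straight to the database. A minimal usage sketch, assuming Django is already configured and at least one KhojUser exists; the file path and contents are made up:

from database.models import KhojUser
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl

user = KhojUser.objects.first()  # any existing user, for illustration only
files = {"/notes/todo.md": "# Todo\n- write docs"}
num_new, num_deleted = MarkdownToJsonl().process(
    files=files, full_corpus=True, user=user, regenerate=False
)
print(f"{num_new} new embeddings, {num_deleted} deleted")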

View file

@ -1,5 +1,6 @@
# Standard Packages
import logging
from typing import Tuple
# External Packages
import requests
@ -7,9 +8,9 @@ import requests
# Internal Packages
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.utils.jsonl import compress_jsonl_data
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser, NotionConfig
from enum import Enum
@ -49,10 +50,12 @@ class NotionBlockType(Enum):
CALLOUT = "callout"
class NotionToJsonl(TextToJsonl):
def __init__(self, config: NotionContentConfig):
class NotionToJsonl(TextEmbeddings):
def __init__(self, config: NotionConfig):
super().__init__(config)
self.config = config
self.config = NotionContentConfig(
token=config.token,
)
self.session = requests.Session()
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
self.unsupported_block_types = [
@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl):
self.body_params = {"page_size": 100}
def process(self, previous_entries=[], files=None, full_corpus=True):
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
current_entries = []
# Get all pages
@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl):
page_entries = self.process_page(p_or_d)
current_entries.extend(page_entries)
return self.update_entries_with_ids(current_entries, previous_entries)
return self.update_entries_with_ids(current_entries, user)
def process_page(self, page):
page_id = page["id"]
@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl):
title = None
return title, content
def update_entries_with_ids(self, current_entries, previous_entries):
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
)
with timer("Write Notion entries to JSONL file", logger):
# Process Each Entry from all Notion entries
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings

View file

@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple
# Internal Packages
from khoj.processor.org_mode import orgnode
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry, TextContentConfig
from khoj.utils.rawconfig import Entry
from khoj.utils import state
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl):
def __init__(self, config: TextContentConfig):
super().__init__(config)
self.config = config
class OrgToJsonl(TextEmbeddings):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
) -> List[Tuple[int, Entry]]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
output_file = self.config.compressed_jsonl
index_heading_entries = self.config.index_heading_entries
index_heading_entries = True
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.ORG,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
)
# Process Each Entry from All Notes Files
with timer("Write org entries to JSONL file", logger):
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_org_entries(org_files: dict[str, str]):

View file

@ -1,28 +1,31 @@
# Standard Packages
import os
import logging
from typing import List
from typing import List, Tuple
import base64
# External Packages
from langchain.document_loaders import PyMuPDFLoader
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
class PdfToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True):
# Extract required fields from config
output_file = self.config.compressed_jsonl
class PdfToJsonl(TextEmbeddings):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.PDF,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
)
with timer("Write PDF entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(jsonl_data, output_file)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_pdf_entries(pdf_files):
@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl):
entry_to_location_map = []
for pdf_file in pdf_files:
try:
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
tmp_file = f"tmp_pdf_file.pdf"
with open(f"{tmp_file}", "wb") as f:
bytes = pdf_files[pdf_file]
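The PDF loader only accepts file paths, so the uploaded bytes are round-tripped through a file on disk before parsing. Below is a self-contained sketch of that byte handling; the real code writes to a fixed tmp_pdf_file.pdf path and then calls langchain's PyMuPDFLoader on it, while this sketch uses tempfile purely to stay self-contained and stops short of invoking the loader.

import tempfile

pdf_bytes = b"%PDF-1.4 ..."  # illustrative payload, as received from the indexer API
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
    tmp.write(pdf_bytes)
    tmp_path = tmp.name
# In the real pipeline, PyMuPDFLoader(tmp_path).load() would parse the pages here.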

View file

@ -4,22 +4,23 @@ from pathlib import Path
from typing import List, Tuple
# Internal Packages
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
from khoj.utils.rawconfig import Entry, TextContentConfig
from database.models import Embeddings, KhojUser, LocalPlaintextConfig
logger = logging.getLogger(__name__)
class PlaintextToJsonl(TextToJsonl):
class PlaintextToJsonl(TextEmbeddings):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
) -> List[Tuple[int, Entry]]:
output_file = self.config.compressed_jsonl
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
entries_with_ids = TextToJsonl.mark_entries_for_update(
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
Embeddings.EmbeddingsType.PLAINTEXT,
key="compiled",
logger=logger,
deletion_filenames=deletion_file_names,
user=user,
regenerate=regenerate,
)
with timer("Write entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries)
# Compress JSONL formatted Data
compress_jsonl_data(plaintext_data, output_file)
return entries_with_ids
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:

View file

@ -2,24 +2,33 @@
from abc import ABC, abstractmethod
import hashlib
import logging
from typing import Callable, List, Tuple, Set
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer
# Internal Packages
from khoj.utils.rawconfig import Entry, TextConfigBase
from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Embeddings, EmbeddingsDates
from database.adapters import EmbeddingsAdapters
logger = logging.getLogger(__name__)
class TextToJsonl(ABC):
def __init__(self, config: TextConfigBase):
class TextEmbeddings(ABC):
def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel()
self.config = config
self.date_filter = DateFilter()
@abstractmethod
def process(
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
) -> List[Tuple[int, Entry]]:
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
...
@staticmethod
@ -38,6 +47,7 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
corpus_id = uuid.uuid4()
# Split entry into chunks of max tokens
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
@ -57,11 +67,103 @@ class TextToJsonl(ABC):
raw=entry.raw,
heading=entry.heading,
file=entry.file,
corpus_id=corpus_id,
)
)
return chunked_entries
def update_embeddings(
self,
current_entries: List[Entry],
file_type: str,
key="compiled",
logger: logging.Logger = None,
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry))
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.info(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = Embeddings.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
# for hashed_val in hashes_for_file:
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
# hashes_to_process.add(hashed_val)
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
embeddings_to_create = []
for hashed_val, embedding in zip(hashes_to_process, embeddings):
entry = hash_to_current_entries[hashed_val]
embeddings_to_create.append(
Embeddings(
user=user,
embeddings=embedding,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_path=entry.file,
file_type=file_type,
hashed_value=hashed_val,
corpus_id=entry.corpus_id,
)
)
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
num_new_embeddings += len(new_embeddings)
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for embedding in new_embeddings:
dates = self.date_filter.extract_dates(embedding.raw)
for date in dates:
dates_to_create.append(
EmbeddingsDates(
date=date,
embeddings=embedding,
)
)
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.info(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger):
for file in hashes_by_file:
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
num_deleted_embeddings += deleted_count
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def mark_entries_for_update(
current_entries: List[Entry],
@ -72,11 +174,11 @@ class TextToJsonl(ABC):
):
# Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries))
if deletion_filenames is not None:
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries))
deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries))
else:
deletion_entry_hashes = []
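The core of update_embeddings() above is a per-file hash diff: every current entry is hashed, compared against the hashes already stored for that user and file, and only the difference is embedded or deleted. A standalone sketch of that diffing step, with made-up data and md5 standing in for whatever hash_func() actually uses:

import hashlib

def hash_entry(compiled: str) -> str:
    # stand-in for TextEmbeddings.hash_func(); md5 is an assumption here
    return hashlib.md5(compiled.encode()).hexdigest()

current_entries = {"notes.org": {"* Buy milk", "* Call Bob"}}
stored_hashes = {"notes.org": {hash_entry("* Buy milk")}}

for file, entries in current_entries.items():
    current_hashes = {hash_entry(entry) for entry in entries}
    to_embed = current_hashes - stored_hashes.get(file, set())   # new entries -> embed and insert
    to_delete = stored_hashes.get(file, set()) - current_hashes  # stale entries -> delete
    print(file, f"{len(to_embed)} to embed, {len(to_delete)} to delete")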

View file

@ -2,14 +2,15 @@
import concurrent.futures
import math
import time
import yaml
import logging
import json
from typing import List, Optional, Union, Any
import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
from sentence_transformers import util
from fastapi import APIRouter, HTTPException, Header, Request, Depends
from starlette.authentication import requires
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.configure import configure_processor, configure_server
@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
ProcessorConfig,
SearchConfig,
@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from fastapi.requests import Request
from database import adapters
from database.adapters import EmbeddingsAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
def map_config_to_object(content_type: str):
if content_type == "org":
return LocalOrgConfig
if content_type == "markdown":
return LocalMarkdownConfig
if content_type == "pdf":
return LocalPdfConfig
if content_type == "plaintext":
return LocalPlaintextConfig
async def map_config_to_db(config: FullConfig, user: KhojUser):
if config.content_type:
if config.content_type.org:
await LocalOrgConfig.objects.filter(user=user).adelete()
await LocalOrgConfig.objects.acreate(
input_files=config.content_type.org.input_files,
input_filter=config.content_type.org.input_filter,
index_heading_entries=config.content_type.org.index_heading_entries,
user=user,
)
if config.content_type.markdown:
await LocalMarkdownConfig.objects.filter(user=user).adelete()
await LocalMarkdownConfig.objects.acreate(
input_files=config.content_type.markdown.input_files,
input_filter=config.content_type.markdown.input_filter,
index_heading_entries=config.content_type.markdown.index_heading_entries,
user=user,
)
if config.content_type.pdf:
await LocalPdfConfig.objects.filter(user=user).adelete()
await LocalPdfConfig.objects.acreate(
input_files=config.content_type.pdf.input_files,
input_filter=config.content_type.pdf.input_filter,
index_heading_entries=config.content_type.pdf.index_heading_entries,
user=user,
)
if config.content_type.plaintext:
await LocalPlaintextConfig.objects.filter(user=user).adelete()
await LocalPlaintextConfig.objects.acreate(
input_files=config.content_type.plaintext.input_files,
input_filter=config.content_type.plaintext.input_filter,
index_heading_entries=config.content_type.plaintext.index_heading_entries,
user=user,
)
if config.content_type.github:
await adapters.set_user_github_config(
user=user,
pat_token=config.content_type.github.pat_token,
repos=config.content_type.github.repos,
)
if config.content_type.notion:
await adapters.set_notion_config(
user=user,
token=config.content_type.notion.token,
)
# If it's a demo instance, prevent updating any of the configuration.
if not state.demo:
@ -64,7 +127,10 @@ if not state.demo:
state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
def get_config_data(request: Request):
user = request.user.object if request.user.is_authenticated else None
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
return state.config
@api.post("/config/data")
@ -73,20 +139,19 @@ if not state.demo:
updated_config: FullConfig,
client: Optional[str] = None,
):
state.config = updated_config
with open(state.config_file, "w") as outfile:
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
outfile.close()
user = request.user.object if request.user.is_authenticated else None
await map_config_to_db(updated_config, user)
configuration_update_metadata = dict()
configuration_update_metadata = {}
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
if state.config.content_type is not None:
configuration_update_metadata["github"] = state.config.content_type.github is not None
configuration_update_metadata["notion"] = state.config.content_type.notion is not None
configuration_update_metadata["org"] = state.config.content_type.org is not None
configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None
configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None
configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None
configuration_update_metadata["github"] = "github" in enabled_content
configuration_update_metadata["notion"] = "notion" in enabled_content
configuration_update_metadata["org"] = "org" in enabled_content
configuration_update_metadata["pdf"] = "pdf" in enabled_content
configuration_update_metadata["markdown"] = "markdown" in enabled_content
if state.config.processor is not None:
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
@ -101,6 +166,7 @@ if not state.demo:
return state.config
@api.post("/config/data/content_type/github", status_code=200)
@requires("authenticated")
async def set_content_config_github_data(
request: Request,
updated_config: Union[GithubContentConfig, None],
@ -108,10 +174,13 @@ if not state.demo:
):
_initialize_config()
if not state.config.content_type:
state.config.content_type = ContentConfig(**{"github": updated_config})
else:
state.config.content_type.github = updated_config
user = request.user.object if request.user.is_authenticated else None
await adapters.set_user_github_config(
user=user,
pat_token=updated_config.pat_token,
repos=updated_config.repos,
)
update_telemetry_state(
request=request,
@ -121,11 +190,7 @@ if not state.demo:
metadata={"content_type": "github"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/notion", status_code=200)
async def set_content_config_notion_data(
@ -135,10 +200,12 @@ if not state.demo:
):
_initialize_config()
if not state.config.content_type:
state.config.content_type = ContentConfig(**{"notion": updated_config})
else:
state.config.content_type.notion = updated_config
user = request.user.object if request.user.is_authenticated else None
await adapters.set_notion_config(
user=user,
token=updated_config.token,
)
update_telemetry_state(
request=request,
@ -148,11 +215,7 @@ if not state.demo:
metadata={"content_type": "notion"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
async def remove_content_config_data(
@ -160,8 +223,7 @@ if not state.demo:
content_type: str,
client: Optional[str] = None,
):
if not state.config or not state.config.content_type:
return {"status": "ok"}
user = request.user.object if request.user.is_authenticated else None
update_telemetry_state(
request=request,
@ -171,31 +233,13 @@ if not state.demo:
metadata={"content_type": content_type},
)
if state.config.content_type:
state.config.content_type[content_type] = None
content_object = map_config_to_object(content_type)
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
if content_type == "github":
state.content_index.github = None
elif content_type == "notion":
state.content_index.notion = None
elif content_type == "plugins":
state.content_index.plugins = None
elif content_type == "pdf":
state.content_index.pdf = None
elif content_type == "markdown":
state.content_index.markdown = None
elif content_type == "org":
state.content_index.org = None
elif content_type == "plaintext":
state.content_index.plaintext = None
else:
logger.warning(f"Request to delete unknown content type: {content_type} via API")
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
async def remove_processor_conversation_config_data(
@ -228,6 +272,7 @@ if not state.demo:
return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/{content_type}", status_code=200)
# @requires("authenticated")
async def set_content_config_data(
request: Request,
content_type: str,
@ -236,10 +281,10 @@ if not state.demo:
):
_initialize_config()
if not state.config.content_type:
state.config.content_type = ContentConfig(**{content_type: updated_config})
else:
state.config.content_type[content_type] = updated_config
user = request.user.object if request.user.is_authenticated else None
content_object = map_config_to_object(content_type)
await adapters.set_text_content_config(user, content_object, updated_config)
update_telemetry_state(
request=request,
@ -249,11 +294,7 @@ if not state.demo:
metadata={"content_type": content_type},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/processor/conversation/openai", status_code=200)
async def set_processor_openai_config_data(
@ -337,24 +378,23 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str])
def get_config_types():
"""Get configured content types"""
if state.config is None or state.config.content_type is None:
raise HTTPException(
status_code=500,
detail="Content types not configured. Configure at least one content type on server and restart it.",
)
def get_config_types(
request: Request,
):
user = request.user.object if request.user.is_authenticated else None
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
configured_content_types = list(enabled_file_types)
if state.config and state.config.content_type:
for ctype in state.config.content_type.dict(exclude_none=True):
configured_content_types.append(ctype)
configured_content_types = state.config.content_type.dict(exclude_none=True)
return [
search_type.value
for search_type in SearchType
if (
search_type.value in configured_content_types
and getattr(state.content_index, search_type.value) is not None
)
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
or search_type == SearchType.All
if (search_type.value in configured_content_types) or search_type == SearchType.All
]
@ -372,6 +412,7 @@ async def search(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
start_time = time.time()
# Run validation checks
@ -390,10 +431,11 @@ async def search(
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
if query_cache_key in state.query_cache:
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
return state.query_cache[query_cache_key]
return state.query_cache[user.uuid][query_cache_key]
# Encode query with filter terms removed
defiltered_query = user_query
@ -407,84 +449,31 @@ async def search(
]
if text_search_models:
with timer("Encoding query took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(
text_search_models[0].bi_encoder.encode(
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
)
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
# query org-mode notes
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.org,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Markdown or t == SearchType.All)
and state.content_index.markdown
and state.search_models.text_search
):
if t in [
SearchType.All,
SearchType.Org,
SearchType.Markdown,
SearchType.Github,
SearchType.Notion,
SearchType.Plaintext,
]:
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
user,
user_query,
state.search_models.text_search,
state.content_index.markdown,
t,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Github or t == SearchType.All)
and state.content_index.github
and state.search_models.text_search
):
# query github issues
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.github,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
# query pdf files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.pdf,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
@ -497,70 +486,6 @@ async def search(
)
]
if (
(t == SearchType.All or t in SearchType)
and state.content_index.plugins
and state.search_models.plugin_search
):
# query specified plugin type
# Get plugin content, search model for specified search type, or the first one if none specified
plugin_search = state.search_models.plugin_search.get(t.value) or next(
iter(state.search_models.plugin_search.values())
)
plugin_content = state.content_index.plugins.get(t.value) or next(
iter(state.content_index.plugins.values())
)
search_futures += [
executor.submit(
text_search.query,
user_query,
plugin_search,
plugin_content,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Notion or t == SearchType.All)
and state.content_index.notion
and state.search_models.text_search
):
# query notion pages
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.notion,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
if (
(t == SearchType.Plaintext or t == SearchType.All)
and state.content_index.plaintext
and state.search_models.text_search
):
# query plaintext files
search_futures += [
executor.submit(
text_search.query,
user_query,
state.search_models.text_search,
state.content_index.plaintext,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
dedupe=dedupe or True,
)
]
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
@ -576,15 +501,19 @@ async def search(
count=results_count,
)
else:
hits, entries = await search_future.result()
hits = await search_future.result()
# Collate results
results += text_search.collate_results(hits, entries, results_count)
results += text_search.collate_results(hits, dedupe=dedupe)
if r:
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
else:
# Sort results across all content types and take top results
results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]
results = sorted(results, key=lambda x: float(x.score))[:results_count]
# Cache results
state.query_cache[query_cache_key] = results
if user:
state.query_cache[user.uuid][query_cache_key] = results
update_telemetry_state(
request=request,
@ -596,8 +525,6 @@ async def search(
host=host,
)
state.previous_query = user_query
end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
@ -614,12 +541,13 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
configure_server(state.config, regenerate=force, search_type=t)
configure_server(state.config, regenerate=force, search_type=t, user=user)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)
@ -774,6 +702,7 @@ async def extract_references_and_questions(
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
@ -781,7 +710,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = []
inferred_queries: List[str] = []
if state.content_index is None:
if not EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
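The search endpoint now keeps a separate query cache per user, keyed by the user's uuid, with the cache key built from the query and search parameters. A standalone sketch of that caching behaviour, where plain dicts stand in for khoj's LRU and FakeUser stands in for KhojUser:

import uuid

query_cache: dict = {}  # user uuid -> {cache_key: results}

class FakeUser:
    def __init__(self):
        self.uuid = uuid.uuid4()

def search(user, user_query, n=5, t="all", r=False, score_threshold=-1.0, dedupe=True):
    cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
    user_cache = query_cache.setdefault(user.uuid, {})
    if cache_key in user_cache:
        return user_cache[cache_key]           # cache hit
    results = [f"result for '{user_query}'"]   # stand-in for the real vector search
    user_cache[cache_key] = results
    return results

user = FakeUser()
search(user, "meeting notes")  # miss: computes and caches
search(user, "meeting notes")  # hit: served from the per-user cache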

View file

@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
from database.models import KhojUser
logger = logging.getLogger(__name__)
@ -40,11 +41,13 @@ def update_telemetry_state(
host: Optional[str] = None,
metadata: Optional[dict] = None,
):
user: KhojUser = request.user.object if request.user.is_authenticated else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
"server_id": str(user.uuid) if user else None,
}
if metadata:

View file

@ -1,6 +1,7 @@
# Standard Packages
import logging
from typing import Optional, Union, Dict
import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state
# Internal Packages
from khoj.utils import state, constants
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.utils.rawconfig import ContentConfig, TextContentConfig
from khoj.search_type import text_search, image_search
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.config import SearchModels
from khoj.utils.constants import default_config
from khoj.utils.helpers import LRU, get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
SearchConfig,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.utils.config import (
ContentIndex,
SearchModels,
)
from database.models import (
KhojUser,
GithubConfig,
NotionConfig,
)
logger = logging.getLogger(__name__)
@ -68,14 +68,14 @@ async def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object if request.user.is_authenticated else None
if x_api_key != "secret":
raise HTTPException(status_code=401, detail="Invalid API Key")
state.config_lock.acquire()
try:
logger.info(f"📬 Updating content index via API call by {client} client")
org_files: Dict[str, str] = {}
markdown_files: Dict[str, str] = {}
pdf_files: Dict[str, str] = {}
pdf_files: Dict[str, bytes] = {}
plaintext_files: Dict[str, str] = {}
for file in files:
@ -86,7 +86,7 @@ async def update(
elif file_type == "markdown":
dict_to_update = markdown_files
elif file_type == "pdf":
dict_to_update = pdf_files
dict_to_update = pdf_files # type: ignore
elif file_type == "plaintext":
dict_to_update = plaintext_files
@ -120,30 +120,31 @@ async def update(
github=None,
notion=None,
plaintext=None,
plugins=None,
)
state.config.content_type = default_content_config
save_config_to_file_updated_state()
configure_search(state.search_models, state.config.search_type)
# Extract required fields from config
state.content_index = configure_content(
loop = asyncio.get_event_loop()
state.content_index = await loop.run_in_executor(
None,
configure_content,
state.content_index,
state.config.content_type,
indexer_input.dict(),
state.search_models,
regenerate=force,
t=t,
full_corpus=False,
force,
t,
False,
user,
)
logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True,
)
finally:
state.config_lock.release()
update_telemetry_state(
request=request,
@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
if search_models is None:
search_models = SearchModels()
# Initialize Search Models
if search_config.asymmetric:
logger.info("🔍 📜 Setting up text search model")
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
@ -187,15 +183,8 @@ def configure_content(
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
full_corpus: bool = True,
user: KhojUser = None,
) -> Optional[ContentIndex]:
def has_valid_text_config(config: TextContentConfig):
return config.input_files or config.input_filter
# Run Validation Checks
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None
if content_index is None:
content_index = ContentIndex()
if t in [type.value for type in state.SearchType]:
@ -209,59 +198,30 @@ def configure_content(
try:
# Initialize Org Notes Search
if (
(t == None or t == state.SearchType.Org.value)
and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"])
and search_models.text_search
):
if content_config.org == None:
logger.info("🦄 No configuration for orgmode notes. Using default configuration.")
default_configuration = default_config["content-type"]["org"] # type: ignore
content_config.org = TextContentConfig(
compressed_jsonl=default_configuration["compressed-jsonl"],
embeddings_file=default_configuration["embeddings-file"],
)
if (t == None or t == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
files.get("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
try:
# Initialize Markdown Search
if (
(t == None or t == state.SearchType.Markdown.value)
and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"])
and search_models.text_search
and files["markdown"]
):
if content_config.markdown == None:
logger.info("💎 No configuration for markdown notes. Using default configuration.")
default_configuration = default_config["content-type"]["markdown"] # type: ignore
content_config.markdown = TextContentConfig(
compressed_jsonl=default_configuration["compressed-jsonl"],
embeddings_file=default_configuration["embeddings-file"],
)
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
content_index.markdown = text_search.setup(
text_search.setup(
MarkdownToJsonl,
files.get("markdown"),
content_config.markdown,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
)
except Exception as e:
@ -269,30 +229,15 @@ def configure_content(
try:
# Initialize PDF Search
if (
(t == None or t == state.SearchType.Pdf.value)
and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"])
and search_models.text_search
and files["pdf"]
):
if content_config.pdf == None:
logger.info("🖨️ No configuration for pdf notes. Using default configuration.")
default_configuration = default_config["content-type"]["pdf"] # type: ignore
content_config.pdf = TextContentConfig(
compressed_jsonl=default_configuration["compressed-jsonl"],
embeddings_file=default_configuration["embeddings-file"],
)
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
content_index.pdf = text_search.setup(
text_search.setup(
PdfToJsonl,
files.get("pdf"),
content_config.pdf,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
)
except Exception as e:
@ -300,30 +245,15 @@ def configure_content(
try:
# Initialize Plaintext Search
if (
(t == None or t == state.SearchType.Plaintext.value)
and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"])
and search_models.text_search
and files["plaintext"]
):
if content_config.plaintext == None:
logger.info("📄 No configuration for plaintext notes. Using default configuration.")
default_configuration = default_config["content-type"]["plaintext"] # type: ignore
content_config.plaintext = TextContentConfig(
compressed_jsonl=default_configuration["compressed-jsonl"],
embeddings_file=default_configuration["embeddings-file"],
)
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
content_index.plaintext = text_search.setup(
text_search.setup(
PlaintextToJsonl,
files.get("plaintext"),
content_config.plaintext,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
)
except Exception as e:
@ -331,7 +261,12 @@ def configure_content(
try:
# Initialize Image Search
if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
if (
(t == None or t == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
):
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(
@ -342,17 +277,17 @@ def configure_content(
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
try:
if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
content_index.github = text_search.setup(
text_search.setup(
GithubToJsonl,
None,
content_config.github,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
config=github_config,
)
except Exception as e:
@ -360,42 +295,24 @@ def configure_content(
try:
# Initialize Notion Search
if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
notion_config = NotionConfig.objects.filter(user=user).first()
if (t == None or t in state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
content_index.notion = text_search.setup(
text_search.setup(
NotionToJsonl,
None,
content_config.notion,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
user=user,
config=notion_config,
)
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
try:
# Initialize External Plugin Search
if t == None and content_config.plugins and search_models.text_search:
logger.info("🔌 Setting up search for plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.setup(
JsonlToJsonl,
None,
plugin_config,
search_models.text_search.bi_encoder,
regenerate=regenerate,
filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
)
except Exception as e:
logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True)
# Invalidate Query Cache
state.query_cache = LRU()
if user:
state.query_cache[user.uuid] = LRU()
return content_index
@ -412,44 +329,9 @@ def load_content(
if content_index is None:
content_index = ContentIndex()
if content_config.org:
logger.info("🦄 Loading orgmode notes")
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.markdown:
logger.info("💎 Loading markdown notes")
content_index.markdown = text_search.load(
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.pdf:
logger.info("🖨️ Loading pdf")
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
if content_config.plaintext:
logger.info("📄 Loading plaintext")
content_index.plaintext = text_search.load(
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.image:
logger.info("🌄 Loading images")
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
if content_config.github:
logger.info("🐙 Loading github")
content_index.github = text_search.load(
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.notion:
logger.info("🔌 Loading notion")
content_index.notion = text_search.load(
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
)
if content_config.plugins:
logger.info("🔌 Loading plugins")
content_index.plugins = {}
for plugin_type, plugin_config in content_config.plugins.items():
content_index.plugins[plugin_type] = text_search.load(
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
)
state.query_cache = LRU()
return content_index
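The indexing endpoint earlier in this file now pushes configure_content onto the default thread-pool executor so the blocking embedding work does not stall the async request handler. A self-contained sketch of that pattern, with a simplified stub standing in for the real configure_content signature:

import asyncio

def configure_content(content_index, files, regenerate, t, full_corpus, user):
    # stand-in for the real, blocking indexing work (simplified signature)
    return f"indexed {len(files)} file(s) for {user}"

async def handle_update(files):
    loop = asyncio.get_running_loop()
    # run the blocking call in the default executor, keeping the event loop free
    return await loop.run_in_executor(
        None, configure_content, None, files, False, None, False, "demo-user"
    )

print(asyncio.run(handle_update({"notes.org": "* hello"})))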

View file

@ -3,10 +3,20 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
from starlette.authentication import requires
from khoj.utils.rawconfig import (
TextContentConfig,
OpenAIProcessorConfig,
FullConfig,
GithubContentConfig,
GithubRepoConfig,
NotionContentConfig,
)
# Internal Packages
from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
import json
@ -29,10 +39,23 @@ def chat_page(request: Request):
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
def map_config_to_object(content_type: str):
if content_type == "org":
return LocalOrgConfig
if content_type == "markdown":
return LocalMarkdownConfig
if content_type == "pdf":
return LocalPdfConfig
if content_type == "plaintext":
return LocalPlaintextConfig
if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
search_type=None,
@ -41,13 +64,13 @@ if not state.demo:
current_config = state.config or json.loads(default_full_config.json())
successfully_configured = {
"pdf": False,
"markdown": False,
"org": False,
"pdf": ("pdf" in enabled_content),
"markdown": ("markdown" in enabled_content),
"org": ("org" in enabled_content),
"image": False,
"github": False,
"notion": False,
"plaintext": False,
"github": ("github" in enabled_content),
"notion": ("notion" in enabled_content),
"plaintext": ("plaintext" in enabled_content),
"enable_offline_model": False,
"conversation_openai": False,
"conversation_gpt4all": False,
@ -56,13 +79,7 @@ if not state.demo:
if state.content_index:
successfully_configured.update(
{
"pdf": state.content_index.pdf is not None,
"markdown": state.content_index.markdown is not None,
"org": state.content_index.org is not None,
"image": state.content_index.image is not None,
"github": state.content_index.github is not None,
"notion": state.content_index.notion is not None,
"plaintext": state.content_index.plaintext is not None,
}
)
@ -84,22 +101,29 @@ if not state.demo:
)
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
@requires(["authenticated"])
def github_config_page(request: Request):
default_copy = constants.default_config.copy()
default_github = default_copy["content-type"]["github"] # type: ignore
user = request.user.object if request.user.is_authenticated else None
current_github_config = get_user_github_config(user)
default_config = TextContentConfig(
compressed_jsonl=default_github["compressed-jsonl"],
embeddings_file=default_github["embeddings-file"],
if current_github_config:
raw_repos = current_github_config.githubrepoconfig.all()
repos = []
for repo in raw_repos:
repos.append(
GithubRepoConfig(
name=repo.name,
owner=repo.owner,
branch=repo.branch,
)
current_config = (
state.config.content_type.github
if state.config and state.config.content_type and state.config.content_type.github
else default_config
)
current_config = GithubContentConfig(
pat_token=current_github_config.pat_token,
repos=repos,
)
current_config = json.loads(current_config.json())
else:
current_config = {} # type: ignore
return templates.TemplateResponse(
"content_type_github_input.html", context={"request": request, "current_config": current_config}
@ -107,18 +131,11 @@ if not state.demo:
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
def notion_config_page(request: Request):
default_copy = constants.default_config.copy()
default_notion = default_copy["content-type"]["notion"] # type: ignore
user = request.user.object if request.user.is_authenticated else None
current_notion_config = get_user_notion_config(user)
default_config = TextContentConfig(
compressed_jsonl=default_notion["compressed-jsonl"],
embeddings_file=default_notion["embeddings-file"],
)
current_config = (
state.config.content_type.notion
if state.config and state.config.content_type and state.config.content_type.notion
else default_config
current_config = NotionContentConfig(
token=current_notion_config.token if current_notion_config else "",
)
current_config = json.loads(current_config.json())
@ -132,18 +149,16 @@ if not state.demo:
if content_type not in VALID_TEXT_CONTENT_TYPES:
return templates.TemplateResponse("config.html", context={"request": request})
default_copy = constants.default_config.copy()
default_content_type = default_copy["content-type"][content_type] # type: ignore
object = map_config_to_object(content_type)
user = request.user.object if request.user.is_authenticated else None
config = object.objects.filter(user=user).first()
if config == None:
config = object.objects.create(user=user)
default_config = TextContentConfig(
compressed_jsonl=default_content_type["compressed-jsonl"],
embeddings_file=default_content_type["embeddings-file"],
)
current_config = (
state.config.content_type[content_type]
if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore
else default_config
current_config = TextContentConfig(
input_files=config.input_files,
input_filter=config.input_filter,
index_heading_entries=config.index_heading_entries,
)
current_config = json.loads(current_config.json())

View file

@ -1,16 +1,9 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
# Internal Packages
from khoj.utils.rawconfig import Entry
from typing import List
class BaseFilter(ABC):
@abstractmethod
def load(self, entries: List[Entry], *args, **kwargs):
...
@abstractmethod
def get_filter_terms(self, query: str) -> List[str]:
...
@ -18,10 +11,6 @@ class BaseFilter(ABC):
def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0
@abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
...
@abstractmethod
def defilter(self, query: str) -> str:
...

View file

@ -25,72 +25,42 @@ class DateFilter(BaseFilter):
# - dt>="last week"
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
raw_date_regex = r"\d{4}-\d{2}-\d{2}"
def __init__(self, entry_key="compiled"):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created date filter index", logger):
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
def extract_dates(self, content):
pattern_matched_dates = re.findall(self.raw_date_regex, content)
# Filter down to valid dates
valid_dates = []
for date_str in pattern_matched_dates:
try:
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d"))
except ValueError:
continue
except OSError:
logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}")
continue
self.date_to_entry_ids[date_in_entry].add(id)
return valid_dates
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
def get_query_date_range(self, query) -> List:
with timer("Extract date range to filter from query", logger):
query_daterange = self.extract_date_range(query)
return query_daterange
def defilter(self, query):
# remove date range filter from query
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
return query
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
with timer("Extract date range to filter from query", logger):
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange == []:
return query, set(range(len(entries)))
query = self.defilter(query)
# return results from cache if exists
cache_key = tuple(query_daterange)
if cache_key in self.cache:
logger.debug(f"Return date filter results from cache")
entries_to_include = self.cache[cache_key]
return query, entries_to_include
if not self.date_to_entry_ids:
self.load(entries)
# find entries containing any dates that fall with date range specified in query
with timer("Mark entries satisfying filter", logger):
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry]
# cache results
self.cache[cache_key] = entries_to_include
return query, entries_to_include
def extract_date_range(self, query):
# find date range filter in query
date_range_matches = re.findall(self.date_regex, query)
@ -138,6 +108,15 @@ class DateFilter(BaseFilter):
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return []
else:
# If the first element is 0, replace it with None
if effective_date_range[0] == 0:
effective_date_range[0] = None
# If the second element is inf, replace it with None
if effective_date_range[1] == inf:
effective_date_range[1] = None
return effective_date_range
def parse(self, date_str, relative_base=None):
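
Roughly sketched below (a standalone illustration, not code from this commit): with the rework above, extract_date_range() returns None for an unbounded side of the range instead of the old 0/inf sentinels, so the database query can skip that bound entirely. The helper name here is hypothetical.

from datetime import datetime
from typing import Optional, Tuple

def to_datetime_bounds(date_range: Optional[list]) -> Tuple[Optional[datetime], Optional[datetime]]:
    "Translate extract_date_range() output into optional datetime bounds."
    if not date_range:
        return None, None
    start_ts, end_ts = date_range
    start = datetime.fromtimestamp(start_ts) if start_ts is not None else None
    end = datetime.fromtimestamp(end_ts) if end_ts is not None else None
    return start, end

# dt>="1984-01-01" now yields [timestamp, None]; only the lower bound is applied.
print(to_datetime_bounds([datetime(1984, 1, 1).timestamp(), None]))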


@ -21,62 +21,13 @@ class FileFilter(BaseFilter):
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created file filter index", logger):
for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)]
def convert_to_regex(self, file_filter: str) -> str:
"Convert file filter to regex"
return file_filter.replace(".", r"\.").replace("*", r".*")
def defilter(self, query: str) -> str:
return re.sub(self.file_filter_regex, "", query).strip()
def apply(self, query, entries):
# Extract file filters from raw query
with timer("Extract files_to_search from query", logger):
raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search:
return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if "/" not in file and "\\" not in file and "*" not in file:
files_to_search += [f"*{file}"]
else:
files_to_search += [file]
# Remove filter terms from original query
query = self.defilter(query)
# Return item from cache if exists
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.debug(f"Return file filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.file_to_entry_map:
self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion
with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(
*[
self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)
],
set(),
)
if not included_entry_indices:
return query, {}
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
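
As a quick illustration of the change above (a standalone sketch, not code from this commit): file filter terms are now emitted as regex fragments rather than file:"..." strings, so they can be matched against stored file paths when querying the database.

import re

def convert_to_regex(file_filter: str) -> str:
    # Mirrors FileFilter.convert_to_regex: escape literal dots, expand * globs
    return file_filter.replace(".", r"\.").replace("*", r".*")

# 'file 1.org' -> 'file 1\.org'; '/path/to/dir/*.org' -> '/path/to/dir/.*\.org'
pattern = convert_to_regex("/path/to/dir/*.org")
assert re.search(pattern, "/path/to/dir/notes.org")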


@ -6,7 +6,7 @@ from typing import List
# Internal Packages
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.helpers import LRU, timer
from khoj.utils.helpers import LRU
logger = logging.getLogger(__name__)
@ -22,21 +22,6 @@ class WordFilter(BaseFilter):
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = (
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
)
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == "":
continue
self.word_to_entry_index[word].add(entry_index)
return self.word_to_entry_index
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
@ -45,47 +30,3 @@ class WordFilter(BaseFilter):
def defilter(self, query: str) -> str:
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = self.defilter(query)
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))
# Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache:
logger.debug(f"Return word filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.word_to_entry_index:
self.load(entries, regenerate=False)
# mark entries that contain all required_words for inclusion
with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(
*[self.word_to_entry_index.get(word, set()) for word in required_words]
)
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
)
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
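
With apply() removed, the word filter only parses +"required" and -"blocked" terms out of the query; inclusion and exclusion now happen in the database layer. A rough standalone sketch of the parsing step, using simplified stand-ins for the filter's own regexes:

import re

required_regex = r'\+"?([\w-]+)"?'  # simplified stand-in for WordFilter.required_regex
blocked_regex = r'-"?([\w-]+)"?'    # simplified stand-in for WordFilter.blocked_regex

query = 'How to git install application? +"Emacs" -"clone"'
required = {word.lower() for word in re.findall(required_regex, query)}
blocked = {word.lower() for word in re.findall(blocked_regex, query)}
print(required, blocked)  # {'emacs'} {'clone'}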


@ -2,25 +2,39 @@
import logging
import math
from pathlib import Path
from typing import List, Tuple, Type, Union
from typing import List, Tuple, Type, Union, Dict
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.search_filter.base_filter import BaseFilter
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextContent, TextSearchModel
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters
from database.models import KhojUser, Embeddings
logger = logging.getLogger(__name__)
search_type_to_embeddings_type = {
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
SearchType.All.value: None,
}
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
@ -117,171 +131,102 @@ def load_embeddings(
async def query(
user: KhojUser,
raw_query: str,
search_model: TextSearchModel,
content: TextContent,
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
if (
content.entries is None
or len(content.entries) == 0
or content.corpus_embeddings is None
or len(content.corpus_embeddings) == 0
):
return [], []
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
file_type = search_type_to_embeddings_type[type.value]
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
return [], []
# If query only had filters it'll be empty now. So short-circuit and return results.
if query.strip() == "":
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
return hits, entries
query = raw_query
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
question_embedding = state.embeddings_model.embed_query(query)
# Find relevant entries for the query
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
top_k = 10
with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
hits = EmbeddingsAdapters.search_with_embeddings(
user=user,
embeddings=question_embedding,
max_results=top_k,
file_type_filter=file_type,
raw_query=raw_query,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
# Score all retrieved entries using the cross-encoder
if rank_results and search_model.cross_encoder:
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
# Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
# Order results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results, hits)
# Deduplicate entries by raw entry text before showing to users
if dedupe:
hits = deduplicate_results(entries, hits)
return hits, entries
return hits
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
return [
SearchResponse.parse_obj(
def collate_results(hits, dedupe=True):
hit_ids = set()
for hit in hits:
if dedupe and hit.corpus_id in hit_ids:
continue
else:
hit_ids.add(hit.corpus_id)
yield SearchResponse.parse_obj(
{
"entry": entries[hit["corpus_id"]].raw,
"score": f"{hit.get('cross-score') or hit.get('score')}",
"entry": hit.raw,
"score": hit.distance,
"additional": {
"file": entries[hit["corpus_id"]].file,
"compiled": entries[hit["corpus_id"]].compiled,
"heading": entries[hit["corpus_id"]].heading,
"file": hit.file_path,
"compiled": hit.compiled,
"heading": hit.heading,
},
}
)
for hit in hits[0:count]
]
def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results=True, hits=hits)
return hits
def setup(
text_to_jsonl: Type[TextToJsonl],
text_to_jsonl: Type[TextEmbeddings],
files: dict[str, str],
config: TextConfigBase,
bi_encoder: BaseEncoder,
regenerate: bool,
filters: List[BaseFilter] = [],
normalize: bool = True,
full_corpus: bool = True,
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
previous_entries = []
if config.compressed_jsonl.exists() and not regenerate:
previous_entries = extract_entries(config.compressed_jsonl)
entries_with_indices = text_to_jsonl(config).process(
previous_entries=previous_entries, files=files, full_corpus=full_corpus
user: KhojUser = None,
config=None,
) -> None:
if config:
num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
# Extract Updated Entries
entries = extract_entries(config.compressed_jsonl)
if is_none_or_empty(entries):
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
raise ValueError(
f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}"
)
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
)
for filter in filters:
filter.load(entries, regenerate=regenerate)
return TextContent(entries, corpus_embeddings, filters)
def load(
config: TextConfigBase,
filters: List[BaseFilter] = [],
) -> TextContent:
# Map notes in text files to (compressed) JSONL formatted file
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
entries = extract_entries(config.compressed_jsonl)
# Compute or Load Embeddings
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = load_embeddings(config.embeddings_file)
for filter in filters:
filter.load(entries, regenerate=False)
return TextContent(entries, corpus_embeddings, filters)
def apply_filters(
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
) -> Tuple[str, List[Entry], torch.Tensor]:
"""Filter query, entries and embeddings before semantic search"""
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return "", [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process(
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
return query, entries, corpus_embeddings
file_names = [file_name for file_name in files]
logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
)
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]["cross-score"] = cross_scores[idx]
hits[idx]["cross_score"] = cross_scores[idx]
return hits
@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
return hits
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
"""Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user."""
with timer("Deduplication Time", logger, state.device):
seen, original_hits_count = set(), len(hits)
hits = [
hit
for hit in hits
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
]
duplicate_hits = original_hits_count - len(hits)
logger.debug(f"Removed {duplicate_hits} duplicates")
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
return hits
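
To summarize the reshaped flow (a minimal sketch over plain objects, not the Django models): semantic search now happens in EmbeddingsAdapters.search_with_embeddings, and collate_results walks the returned hits in ranked order as a generator, skipping repeated corpus_id values so chunks of the same source entry are only shown once.

from dataclasses import dataclass
from typing import Iterable, Iterator

@dataclass
class Hit:  # stand-in for an Embeddings row returned by the database search
    corpus_id: str
    raw: str
    distance: float

def dedupe_hits(hits: Iterable[Hit]) -> Iterator[Hit]:
    "Yield hits in order, skipping repeats of the same corpus_id."
    seen = set()
    for hit in hits:
        if hit.corpus_id in seen:
            continue
        seen.add(hit.corpus_id)
        yield hit

hits = [Hit("a", "entry a", 0.1), Hit("a", "entry a", 0.2), Hit("b", "entry b", 0.3)]
assert [h.corpus_id for h in dedupe_hits(hits)] == ["a", "b"]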


@ -2,6 +2,7 @@
import argparse
import pathlib
from importlib.metadata import version
import os
# Internal Packages
from khoj.utils.helpers import resolve_absolute_path
@ -34,6 +35,12 @@ def cli(args=None):
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
parser.add_argument(
"--anonymous-mode",
action="store_true",
default=False,
help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
)
args = parser.parse_args(args)
@ -51,6 +58,8 @@ def cli(args=None):
else:
args = run_migrations(args)
args.config = parse_config_from_file(args.config_file)
if os.environ.get("DEBUG"):
args.config.app.should_log_telemetry = False
return args
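
A small sketch of the new behavior (an illustrative argparse snippet, not khoj's full cli()): --anonymous-mode is a plain store_true flag, and setting the DEBUG environment variable turns telemetry logging off.

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--anonymous-mode",
    action="store_true",
    default=False,
    help="Run without requiring users to log in.",
)
args = parser.parse_args(["--anonymous-mode"])
should_log_telemetry = not os.environ.get("DEBUG")  # DEBUG=1 disables telemetry
print(args.anonymous_mode, should_log_telemetry)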


@ -41,9 +41,7 @@ class ProcessorType(str, Enum):
@dataclass
class TextContent:
entries: List[Entry]
corpus_embeddings: torch.Tensor
filters: List[BaseFilter]
enabled: bool
@dataclass
@ -67,21 +65,13 @@ class ImageSearchModel:
@dataclass
class ContentIndex:
org: Optional[TextContent] = None
markdown: Optional[TextContent] = None
pdf: Optional[TextContent] = None
github: Optional[TextContent] = None
notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
plaintext: Optional[TextContent] = None
plugins: Optional[Dict[str, TextContent]] = None
@dataclass
class SearchModels:
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
plugin_search: Optional[Dict[str, TextSearchModel]] = None
@dataclass


@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
empty_config = {
"content-type": {


@ -5,27 +5,37 @@ from typing import Optional
from bs4 import BeautifulSoup
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
from khoj.utils.rawconfig import TextContentConfig, ContentConfig
from khoj.utils.rawconfig import TextContentConfig
from khoj.utils.config import SearchType
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig
logger = logging.getLogger(__name__)
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
files = {}
if config is None:
if search_type == SearchType.All or search_type == SearchType.Org:
org_config = LocalOrgConfig.objects.filter(user=user).first()
files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
if search_type == SearchType.All or search_type == SearchType.Markdown:
markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
if search_type == SearchType.All or search_type == SearchType.Plaintext:
plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
if search_type == SearchType.All or search_type == SearchType.Pdf:
pdf_config = LocalPdfConfig.objects.filter(user=user).first()
files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
return files
if search_type == SearchType.All or search_type == SearchType.Org:
files["org"] = get_org_files(config.org) if config.org else {}
if search_type == SearchType.All or search_type == SearchType.Markdown:
files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
if search_type == SearchType.All or search_type == SearchType.Plaintext:
files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
if search_type == SearchType.All or search_type == SearchType.Pdf:
files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
return files
def construct_config_from_db(db_config) -> TextContentConfig:
return TextContentConfig(
input_files=db_config.input_files,
input_filter=db_config.input_filter,
index_heading_entries=db_config.index_heading_entries,
)
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
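
A rough usage sketch of the reworked sync path (assumes a configured Django environment; the user lookup below is hypothetical): collect_files now reads each user's Local*Config rows from the database instead of a server-wide ContentConfig.

from khoj.utils.fs_syncer import collect_files
from khoj.utils.config import SearchType
from database.models import KhojUser

user = KhojUser.objects.filter(username="alice").first()  # hypothetical user lookup
org_files = collect_files(search_type=SearchType.Org, user=user)["org"]
print(f"Collected {len(org_files)} org files for {user}")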


@ -209,10 +209,12 @@ def log_telemetry(
if not app_config or not app_config.should_log_telemetry:
return []
if properties.get("server_id") is None:
properties["server_id"] = get_server_id()
# Populate telemetry data to log
request_body = {
"telemetry_type": telemetry_type,
"server_id": get_server_id(),
"server_version": version("khoj-assistant"),
"os": platform.system(),
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),


@ -1,13 +1,14 @@
# System Packages
import json
from pathlib import Path
from typing import List, Dict, Optional, Union, Any
from typing import List, Dict, Optional
import uuid
# External Packages
from pydantic import BaseModel, validator
from pydantic import BaseModel
# Internal Packages
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
from khoj.utils.helpers import to_snake_case_from_dash
class ConfigBase(BaseModel):
@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase):
embeddings_file: Path
class TextContentConfig(TextConfigBase):
class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
index_heading_entries: Optional[bool] = False
@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase):
branch: Optional[str] = "master"
class GithubContentConfig(TextConfigBase):
class GithubContentConfig(ConfigBase):
pat_token: str
repos: List[GithubRepoConfig]
class NotionContentConfig(TextConfigBase):
class NotionContentConfig(ConfigBase):
token: str
@ -63,7 +64,6 @@ class ContentConfig(ConfigBase):
pdf: Optional[TextContentConfig]
plaintext: Optional[TextContentConfig]
github: Optional[GithubContentConfig]
plugins: Optional[Dict[str, TextContentConfig]]
notion: Optional[NotionContentConfig]
@ -122,7 +122,8 @@ class FullConfig(ConfigBase):
class SearchResponse(ConfigBase):
entry: str
score: str
score: float
cross_score: Optional[float]
additional: Optional[dict]
@ -131,14 +132,21 @@ class Entry:
compiled: str
heading: Optional[str]
file: Optional[str]
corpus_id: str
def __init__(
self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None
self,
raw: str = None,
compiled: str = None,
heading: Optional[str] = None,
file: Optional[str] = None,
corpus_id: uuid.UUID = None,
):
self.raw = raw
self.compiled = compiled
self.heading = heading
self.file = file
self.corpus_id = str(corpus_id)
def to_json(self) -> str:
return json.dumps(self.__dict__, ensure_ascii=False)
@ -153,4 +161,5 @@ class Entry:
compiled=dictionary["compiled"],
file=dictionary.get("file", None),
heading=dictionary.get("heading", None),
corpus_id=dictionary.get("corpus_id", None),
)
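
Entries now carry a corpus_id (stored as a stringified UUID) so chunks split from the same source entry can be grouped and deduplicated after database search. A brief sketch of the round-trip:

import json
import uuid

from khoj.utils.rawconfig import Entry

corpus_id = uuid.uuid4()
chunk = Entry(
    raw="* Heading\nBody text",
    compiled="Heading. Body text",
    heading="Heading",
    file="notes.org",
    corpus_id=corpus_id,
)
# corpus_id survives the to_json()/from_dict() round-trip as a string
assert Entry.from_dict(json.loads(chunk.to_json())).corpus_id == str(corpus_id)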


@ -2,6 +2,7 @@
import threading
from typing import List, Dict
from packaging import version
from collections import defaultdict
# External Packages
import torch
@ -12,10 +13,13 @@ from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State
config = FullConfig()
search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
config_file: Path = None
@ -23,14 +27,14 @@ verbose: int = 0
host: str = None
port: int = None
cli_args: List[str] = None
query_cache = LRU()
query_cache: Dict[str, LRU] = defaultdict(LRU)
config_lock = threading.Lock()
chat_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []
previous_query: str = None
demo: bool = False
khoj_version: str = None
anonymous_mode: bool = False
if torch.cuda.is_available():
# Use CUDA GPU
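
Since query_cache is now a defaultdict of LRU caches, results can be cached per user rather than shared server-wide. A small sketch of the intended access pattern (the per-user cache key shown here is an assumption):

from collections import defaultdict

from khoj.utils.helpers import LRU

query_cache = defaultdict(LRU)  # mirrors khoj.utils.state.query_cache

user_key = "user-uuid-1234"  # assumed per-user key, e.g. str(user.uuid)
query_key = ("what did I do last week?", "org")
if query_key not in query_cache[user_key]:
    query_cache[user_key][query_key] = ["...search results..."]
print(query_cache[user_key][query_key])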


@ -1,15 +1,19 @@
# External Packages
import os
from copy import deepcopy
from fastapi.testclient import TestClient
from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import factory
import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages
from app.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
)
from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from database.models import (
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
LocalPdfConfig,
GithubConfig,
KhojUser,
GithubRepoConfig,
)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
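
A minimal illustrative test (not one from this commit) showing how the factory above is typically exercised; each call persists a fresh KhojUser row when database access is enabled:

import pytest

@pytest.mark.django_db
def test_user_factory_creates_distinct_users():
    first, second = UserFactory(), UserFactory()
    assert first.uuid != second.uuid
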
@pytest.fixture(scope="session")
@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
return search_config
@pytest.mark.django_db
@pytest.fixture
def default_user():
return UserFactory()
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.mark.django_db
@pytest.fixture(scope="function")
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
LocalOrgConfig.objects.create(
input_files=None,
input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
index_heading_entries=False,
user=default_user,
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
content_config.plugins = {
"plugin1": TextContentConfig(
input_files=[content_dir.joinpath("notes.jsonl.gz")],
input_filter=None,
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
)
}
text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
if os.getenv("GITHUB_PAT_TOKEN"):
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
GithubConfig.objects.create(
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
user=default_user,
)
GithubRepoConfig.objects.create(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
github_config=GithubConfig.objects.get(user=default_user),
)
content_config.plaintext = TextContentConfig(
LocalPlaintextConfig.objects.create(
input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
)
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl,
None,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
user=default_user,
)
return content_config
@pytest.fixture(scope="session")
def md_content_config(tmp_path_factory):
content_dir = tmp_path_factory.mktemp("content")
# Generate Embeddings for Markdown Content
content_config = ContentConfig()
content_config.markdown = TextContentConfig(
def md_content_config():
markdown_config = LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
)
return content_config
return markdown_config
@pytest.fixture(scope="session")
@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
all_files = fs_syncer.collect_files(state.config.content_type)
all_files = fs_syncer.collect_files()
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
@pytest.fixture(scope="function")
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
def fastapi_app():
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return app
@pytest.fixture(scope="function")
def client(
content_config: ContentConfig,
search_config: SearchConfig,
processor_config: ProcessorConfig,
default_user: KhojUser,
):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.content_index.plaintext = text_search.setup(
text_search.setup(
PlaintextToJsonl,
get_sample_data("plaintext"),
content_config.plaintext,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@ -288,7 +285,6 @@ def client_offline_chat(
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(state.config.content_type)
@ -298,6 +294,7 @@ def client_offline_chat(
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat)
state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@ -306,9 +303,11 @@ def client_offline_chat(
@pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig):
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
input_filters = org_config.input_filter
new_org_file = Path(input_filters[0]).parent / "new_file.org"
new_org_file.touch()
yield new_org_file
@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
@pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
new_org_config = deepcopy(content_config.org)
new_org_config.input_files = [f"{new_org_file}"]
new_org_config.input_filter = None
return new_org_config
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
return LocalOrgConfig.objects.filter(user=default_user).first()
@pytest.fixture(scope="function")

tests/data/config.yml

@ -9,17 +9,6 @@ content-type:
input-filter:
- '*.org'
- ~/notes/*.org
plugins:
content_plugin_1:
compressed-jsonl: content_plugin_1.jsonl.gz
embeddings-file: content_plugin_1_embeddings.pt
input-files:
- content_plugin_1_new.jsonl.gz
content_plugin_2:
compressed-jsonl: content_plugin_2.jsonl.gz
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
enable-offline-chat: false
search-type:
asymmetric:


@ -48,14 +48,3 @@ def test_cli_config_from_file():
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
assert len(actual_args.config.content_type.plugins.keys()) == 2
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
Path("content_plugin_1_new.jsonl.gz")
]
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
"content_plugin_1.jsonl.gz"
)
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
"content_plugin_2_embeddings.pt"
)


@ -2,22 +2,21 @@
from io import BytesIO
from PIL import Image
from urllib.parse import quote
import pytest
# External Packages
from fastapi.testclient import TestClient
import pytest
from fastapi import FastAPI
# Internal Packages
from app.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from database.models import KhojUser
from database.adapters import EmbeddingsAdapters
# Test
@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
# ----------------------------------------------------------------------------------------------------
def test_search_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}")
# Assert
@ -75,7 +74,7 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client):
def test_get_configured_types_via_api(client, sample_org_data):
# Act
response = client.get(f"/api/config/types")
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
assert list(enabled_types) == ["org"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_only_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
# Arrange
config.content_type = ContentConfig()
config.content_type.plugins = content_config.plugins
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert response.json() == ["all", "plugin1"]
assert response.json() == ["all", "org", "markdown", "image"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
config.content_type = content_config
config.content_type.plugins = None
state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
configure_routes(app)
client = TestClient(app)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert "plugin1" not in response.json()
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_content_config():
# Arrange
config.content_type = ContentConfig()
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
# Act
response = client.get(f"/api/config/types")
@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
assert response.status_code == 200
assert response.json() == ["all"]
# Restore
state.config.content_type = original_config
# ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
# Arrange
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote("How to git install application?")
# Act
@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
filters = [WordFilter(), FileFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('+"Emacs" file:"*.org"')
@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_include_filter(client, sample_org_data):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote('How to git install application? +"Emacs"')
# Act
@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_exclude_filter(client, sample_org_data):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('How to git install application? -"clone"')
@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
assert "clone" not in search_result
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 0
def get_sample_files_data():
return {
"files": ("path/to/filename.org", "* practicing piano", "text/org"),


@ -1,53 +1,12 @@
# Standard Packages
import re
from datetime import datetime
from math import inf
# External Packages
import pytest
# Internal Packages
from khoj.search_filter.date_filter import DateFilter
from khoj.utils.rawconfig import Entry
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
def test_date_filter():
entries = [
Entry(compiled="Entry with no date", raw="Entry with no date"),
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
]
q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == "head tail"
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1, 2}
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@ -56,8 +15,8 @@ def test_extract_date_range():
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),


@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {}
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
@ -108,7 +84,7 @@ def test_get_file_filter_terms():
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
# Assert
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
def arrange_content():


@ -1,78 +0,0 @@
# Internal Packages
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.utils.rawconfig import Entry
def test_process_entries_from_single_input_jsonl(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
"""
input_jsonl_file = create_file(tmp_path, input_jsonl)
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == input_jsonl
def test_process_entries_from_multiple_input_jsonls(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
def test_get_jsonl_files(tmp_path):
"Ensure JSONL files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
# Include via input-file field
file1 = create_file(tmp_path, filename="notes.jsonl")
# Not included by any filter
create_file(tmp_path, filename="not-included-jsonl.jsonl")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / "notes.jsonl"]
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
# Act
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="test.jsonl"):
jsonl_file = tmp_path / filename
jsonl_file.touch()
if entry:
jsonl_file.write_text(entry)
return jsonl_file


@ -4,7 +4,7 @@ import os
# Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files
@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens(
TextEmbeddings.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
)
@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act
# Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit


@ -7,6 +7,7 @@ from pathlib import Path
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path):
@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
assert set(extracted_plaintext_files.keys()) == set(expected_files)
def test_parse_html_plaintext_file(content_config):
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
"Ensure HTML files are parsed correctly"
# Arrange
# Setup input-files, input-filters
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)


@ -3,23 +3,30 @@ import logging
import locale
from pathlib import Path
import os
import asyncio
# External Packages
import pytest
# Internal Packages
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels
from khoj.utils.fs_syncer import get_org_files
from khoj.utils.fs_syncer import get_org_files, collect_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
logger = logging.getLogger(__name__)
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
# Test
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
@pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error(
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
# ----------------------------------------------------------------------------------------------------
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
@pytest.mark.django_db
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
# Arrange
orgfile = tmp_path / "directory.org" / "file.org"
orgfile.parent.mkdir()
with open(orgfile, "w") as f:
f.write("* Heading\n- List item\n")
org_content_config = TextContentConfig(
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
LocalOrgConfig.objects.create(
input_filter=[f"{tmp_path}/**/*"],
input_files=None,
user=default_user,
)
org_files = collect_files(user=default_user)["org"]
# Act
# should not raise IsADirectoryError; should return the orgfile contents
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
verify_embeddings(0, default_user)
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
# Arrange
data = get_org_files(content_config.org)
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert len(notes_model.entries) == 10
assert len(notes_model.corpus_embeddings) == 10
assert "Deleting all embeddings for file type org" in caplog.records[1].message
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# ----------------------------------------------------------------------------------------------------
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
data = get_org_files(content_config.org)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
assert "Creating index from scratch." in initial_logs
assert "Creating index from scratch." not in final_logs
assert "Deleting all embeddings for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.anyio
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
# @pytest.mark.asyncio
async def test_text_search(search_config: SearchConfig):
# Arrange
data = get_org_files(content_config.org)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
index_heading_entries=False,
user=default_user,
)
data = get_org_files(org_config)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
text_search.setup,
OrgToJsonl,
data,
True,
True,
default_user,
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
hits = await text_search.query(default_user, query)
# Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
search_result = results[0].entry
assert "git clone" in search_result
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 2
assert len(initial_notes_model.corpus_embeddings) == 2
record = caplog.records[1]
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
data = {
"readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/
/Allow natural language search on user content like notes, images using transformer based models/
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
}
text_search.setup(
OrgToJsonl,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
max_tokens = 256
@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
with caplog.at_level(logging.INFO):
text_search.setup(
OrgToJsonl,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
full_corpus=False,
user=default_user,
)
record = caplog.records[1]
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 5
assert len(initial_notes_model.corpus_embeddings) == 5
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# append org-mode entry to first org input file in config
content_config.org.input_files = [f"{new_org_file}"]
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(content_config.org)
data = get_org_files(org_config)
# Act
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert len(regenerated_notes_model.entries) == 11
assert len(regenerated_notes_model.corpus_embeddings) == 11
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, regenerated_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Act
# load embeddings, entries, notes model after adding new org-mode file
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) == 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(initial_index, updated_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# Act
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify one entry was removed from the index after deleting it from the org file
assert len(initial_index.entries) == len(updated_index.entries) + 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(updated_index, initial_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
data = get_org_files(content_config.org)
data = get_org_files(org_config)
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
final_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, final_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
text_search.setup(
GithubToJsonl,
{},
regenerate=True,
user=default_user,
config=github_config,
)
# Assert
assert len(github_model.entries) > 1
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
assert embeddings > 1
def compare_index(initial_notes_model, final_notes_model):
mismatched_entries, mismatched_embeddings = [], []
for index in range(len(initial_notes_model.entries)):
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
mismatched_entries.append(index)
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
for index in range(len(initial_notes_model.corpus_embeddings)):
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
mismatched_embeddings.append(index)
error_details = ""
if mismatched_entries:
mismatched_entries_str = ",".join(map(str, mismatched_entries))
error_details += f"Entries at {mismatched_entries_str} not equal\n"
if mismatched_embeddings:
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
return error_details
def verify_embeddings(expected_count, user):
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
assert embeddings == expected_count

View file

@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry
def test_no_word_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_word_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_word_include_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2, 3}
def test_word_include_and_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2}
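These tests exercise the +"term" / -"term" query syntax. A rough sketch of how such terms could be pulled out of a query string is shown below; the regexes are illustrative and not necessarily how WordFilter is implemented.

import re

def split_filter_terms(query: str):
    # Collect quoted include (+"word") and exclude (-"word") terms,
    # then strip the filter expressions from the query.
    required = set(re.findall(r'\+"([^"]+)"', query))
    blocked = set(re.findall(r'-"([^"]+)"', query))
    cleaned = re.sub(r'\s*[+-]"[^"]+"', "", query).strip()
    return cleaned, required, blocked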
def test_get_word_filter_terms():