mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)
- Partition configuration for indexing local data based on user accounts - Store indexed data in an underlying postgres db using the `pgvector` extension - Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time - Apply filters using SQL queries - Start removing many server-level configuration settings - Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB. - Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
parent
963cd165eb
commit
216acf545f
60 changed files with 1827 additions and 1792 deletions
48
.github/workflows/pre-commit.yml
vendored
Normal file
48
.github/workflows/pre-commit.yml
vendored
Normal file
|
@ -0,0 +1,48 @@
|
|||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- src/**
|
||||
- tests/**
|
||||
- config/**
|
||||
- pyproject.toml
|
||||
- .pre-commit-config.yml
|
||||
- .github/workflows/test.yml
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- src/khoj/**
|
||||
- tests/**
|
||||
- config/**
|
||||
- pyproject.toml
|
||||
- .pre-commit-config.yml
|
||||
- .github/workflows/test.yml
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
|
||||
- name: ⏬️ Install Dependencies
|
||||
run: |
|
||||
sudo apt update && sudo apt install -y libegl1
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: ⬇️ Install Application
|
||||
run: pip install --upgrade .[dev]
|
||||
|
||||
- name: 🌡️ Validate Application
|
||||
run: pre-commit run --hook-stage manual --all
|
50
.github/workflows/test.yml
vendored
50
.github/workflows/test.yml
vendored
|
@ -2,10 +2,8 @@ name: test
|
|||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- 'master'
|
||||
paths:
|
||||
- src/khoj/**
|
||||
- src/**
|
||||
- tests/**
|
||||
- config/**
|
||||
- pyproject.toml
|
||||
|
@ -13,7 +11,7 @@ on:
|
|||
- .github/workflows/test.yml
|
||||
push:
|
||||
branches:
|
||||
- 'master'
|
||||
- master
|
||||
paths:
|
||||
- src/khoj/**
|
||||
- tests/**
|
||||
|
@ -26,6 +24,7 @@ jobs:
|
|||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
container: ubuntu:jammy
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
@ -33,6 +32,17 @@ jobs:
|
|||
- '3.9'
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: ankane/pgvector
|
||||
env:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_USER: postgres
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
|
@ -43,17 +53,37 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
|
||||
- name: ⏬️ Install Dependencies
|
||||
- name: Install Git
|
||||
run: |
|
||||
sudo apt update && sudo apt install -y libegl1
|
||||
apt update && apt install -y git
|
||||
|
||||
- name: ⏬️ Install Dependencies
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
run: |
|
||||
apt update && apt install -y libegl1 sqlite3 libsqlite3-dev libsqlite3-0
|
||||
|
||||
- name: ⬇️ Install Postgres
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
run : |
|
||||
apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
|
||||
|
||||
- name: ⬇️ Install pip
|
||||
run: |
|
||||
apt install -y python3-pip
|
||||
python -m ensurepip --upgrade
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: ⬇️ Install Application
|
||||
run: pip install --upgrade .[dev]
|
||||
|
||||
- name: 🌡️ Validate Application
|
||||
run: pre-commit run --hook-stage manual --all
|
||||
run: sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && pip install --upgrade .[dev]
|
||||
|
||||
- name: 🧪 Test Application
|
||||
env:
|
||||
POSTGRES_HOST: postgres
|
||||
POSTGRES_PORT: 5432
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
run: pytest
|
||||
timeout-minutes: 10
|
||||
|
|
11
Dockerfile
11
Dockerfile
|
@ -8,13 +8,20 @@ RUN apt update -y && apt -y install python3-pip git
|
|||
WORKDIR /app
|
||||
|
||||
# Install Application
|
||||
COPY . .
|
||||
COPY pyproject.toml .
|
||||
COPY README.md .
|
||||
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
|
||||
pip install --no-cache-dir .
|
||||
|
||||
# Copy Source Code
|
||||
COPY . .
|
||||
|
||||
# Set the PYTHONPATH environment variable in order for it to find the Django app.
|
||||
ENV PYTHONPATH=/app/src:$PYTHONPATH
|
||||
|
||||
# Run the Application
|
||||
# There are more arguments required for the application to run,
|
||||
# but these should be passed in through the docker-compose.yml file.
|
||||
ARG PORT
|
||||
EXPOSE ${PORT}
|
||||
ENTRYPOINT ["khoj"]
|
||||
ENTRYPOINT ["python3", "src/khoj/main.py"]
|
||||
|
|
|
@ -1,7 +1,21 @@
|
|||
version: "3.9"
|
||||
services:
|
||||
database:
|
||||
image: ankane/pgvector
|
||||
ports:
|
||||
- "5432:5432"
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
volumes:
|
||||
- khoj_db:/var/lib/postgresql/data/
|
||||
server:
|
||||
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
|
||||
image: ghcr.io/khoj-ai/khoj:latest
|
||||
# Uncomment the following two lines to build from source. This will take a few minutes. Leave them commented out if you want to use the official image.
|
||||
# build:
|
||||
# context: .
|
||||
ports:
|
||||
# If changing the local port (left hand side), no other changes required.
|
||||
# If changing the remote port (right hand side),
|
||||
|
@ -26,8 +40,15 @@ services:
|
|||
- ./tests/data/models/:/root/.khoj/search/
|
||||
- khoj_config:/root/.khoj/
|
||||
# Use 0.0.0.0 to explicitly set the host IP for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
||||
environment:
|
||||
- POSTGRES_DB=postgres
|
||||
- POSTGRES_USER=postgres
|
||||
- POSTGRES_PASSWORD=postgres
|
||||
- POSTGRES_HOST=database
|
||||
- POSTGRES_PORT=5432
|
||||
command: --host="0.0.0.0" --port=42110 -vv
|
||||
|
||||
|
||||
volumes:
|
||||
khoj_config:
|
||||
khoj_db:
|
||||
|
|
|
@ -59,13 +59,15 @@ dependencies = [
|
|||
"requests >= 2.26.0",
|
||||
"bs4 >= 0.0.1",
|
||||
"anyio == 3.7.1",
|
||||
"pymupdf >= 1.23.3",
|
||||
"pymupdf >= 1.23.5",
|
||||
"django == 4.2.5",
|
||||
"authlib == 1.2.1",
|
||||
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||
"itsdangerous == 2.1.2",
|
||||
"httpx == 0.25.0",
|
||||
"pgvector == 0.2.3",
|
||||
"psycopg2-binary == 2.9.9",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
@ -91,6 +93,8 @@ dev = [
|
|||
"mypy >= 1.0.1",
|
||||
"black >= 23.1.0",
|
||||
"pre-commit >= 3.0.4",
|
||||
"pytest-django == 4.5.2",
|
||||
"pytest-asyncio == 0.21.1",
|
||||
]
|
||||
|
||||
[tool.hatch.version]
|
||||
|
|
4
pytest.ini
Normal file
4
pytest.ini
Normal file
|
@ -0,0 +1,4 @@
|
|||
[pytest]
|
||||
DJANGO_SETTINGS_MODULE = app.settings
|
||||
pythonpath = . src
|
||||
testpaths = tests
|
60
src/app/README.md
Normal file
60
src/app/README.md
Normal file
|
@ -0,0 +1,60 @@
|
|||
# Django App
|
||||
|
||||
Khoj uses Django as the backend framework primarily for its powerful ORM and the admin interface. The Django app is located in the `src/app` directory. We have one installed app, under the `/database/` directory. This app is responsible for all the database related operations and holds all of our models. You can find the extensive Django documentation [here](https://docs.djangoproject.com/en/4.2/) 🌈.
|
||||
|
||||
## Setup (Docker)
|
||||
|
||||
### Prerequisites
|
||||
1. Ensure you have [Docker](https://docs.docker.com/get-docker/) installed.
|
||||
2. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
|
||||
|
||||
### Run
|
||||
|
||||
Using the `docker-compose.yml` file in the root directory, you can run the Khoj app using the following command:
|
||||
```bash
|
||||
docker-compose up
|
||||
```
|
||||
|
||||
## Setup (Local)
|
||||
|
||||
### Install dependencies
|
||||
|
||||
```bash
|
||||
pip install -e '.[dev]'
|
||||
```
|
||||
|
||||
### Setup the database
|
||||
|
||||
1. Ensure you have Postgres installed. For macOS, you can use [Postgres.app](https://postgresapp.com/).
|
||||
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. The instructions are reproduced below for convenience.
|
||||
|
||||
```bash
|
||||
cd /tmp
|
||||
git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
|
||||
cd pgvector
|
||||
make
|
||||
make install # may need sudo
|
||||
```
|
||||
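3. Create a database. By default the app connects to a database named `khoj`; set the `POSTGRES_DB` environment variable to use a different name.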
|
||||
|
||||
### Make migrations
|
||||
|
||||
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
|
||||
|
||||
```bash
|
||||
python3 src/manage.py makemigrations
|
||||
```
|
||||
|
||||
### Run migrations
|
||||
|
||||
This command will run any pending migrations in your application.
|
||||
```bash
|
||||
python3 src/manage.py migrate
|
||||
```
|
||||
|
||||
### Run the server
|
||||
|
||||
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
|
||||
```bash
|
||||
python3 src/khoj/main.py
|
||||
```
|
|
@ -77,8 +77,12 @@ WSGI_APPLICATION = "app.wsgi.application"
|
|||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.sqlite3",
|
||||
"NAME": BASE_DIR / "db.sqlite3",
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"HOST": os.getenv("POSTGRES_HOST", "localhost"),
|
||||
"PORT": os.getenv("POSTGRES_PORT", "5432"),
|
||||
"USER": os.getenv("POSTGRES_USER", "postgres"),
|
||||
"NAME": os.getenv("POSTGRES_DB", "khoj"),
|
||||
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,15 +1,30 @@
|
|||
from typing import Type, TypeVar
|
||||
from typing import Type, TypeVar, List
|
||||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from pgvector.django import CosineDistance
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
|
||||
# Import sync_to_async from asgiref (the ASGI helper library installed as a Django dependency)
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from database.models import KhojUser, GoogleUser, NotionConfig
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
GoogleUser,
|
||||
NotionConfig,
|
||||
GithubConfig,
|
||||
Embeddings,
|
||||
GithubRepoConfig,
|
||||
)
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||
|
||||
|
@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
|||
|
||||
async def create_google_user(token: dict) -> KhojUser:
|
||||
user_info = token.get("userinfo")
|
||||
user = await KhojUser.objects.acreate(
|
||||
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
|
||||
)
|
||||
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
|
||||
await user.asave()
|
||||
await GoogleUser.objects.acreate(
|
||||
sub=user_info.get("sub"),
|
||||
|
@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
|
|||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid user")
|
||||
return user
|
||||
|
||||
|
||||
def get_all_users() -> BaseManager[KhojUser]:
|
||||
return KhojUser.objects.all()
|
||||
|
||||
|
||||
def get_user_github_config(user: KhojUser):
|
||||
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
if not config:
|
||||
return None
|
||||
return config
|
||||
|
||||
|
||||
def get_user_notion_config(user: KhojUser):
|
||||
config = NotionConfig.objects.filter(user=user).first()
|
||||
if not config:
|
||||
return None
|
||||
return config
|
||||
|
||||
|
||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
||||
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
||||
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
||||
await object.objects.filter(user=user).adelete()
|
||||
await object.objects.acreate(
|
||||
input_files=deduped_files,
|
||||
input_filter=deduped_filters,
|
||||
index_heading_entries=updated_config.index_heading_entries,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||
config = await GithubConfig.objects.filter(user=user).afirst()
|
||||
|
||||
if not config:
|
||||
config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
|
||||
else:
|
||||
config.pat_token = pat_token
|
||||
await config.asave()
|
||||
await config.githubrepoconfig.all().adelete()
|
||||
|
||||
for repo in repos:
|
||||
await GithubRepoConfig.objects.acreate(
|
||||
name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
class EmbeddingsAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
date_filter = DateFilter()
|
||||
|
||||
@staticmethod
|
||||
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_file(user: KhojUser, file_path: str):
|
||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_embeddings(user: KhojUser, file_type: str):
|
||||
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
|
||||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||
|
||||
@staticmethod
|
||||
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
|
||||
return embeddings.filter(
|
||||
embeddingsdates__date__gte=start_date,
|
||||
embeddingsdates__date__lte=end_date,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def user_has_embeddings(user: KhojUser):
|
||||
return await Embeddings.objects.filter(user=user).aexists()
|
||||
|
||||
@staticmethod
|
||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||
q_filter_terms = Q()
|
||||
|
||||
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
|
||||
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
|
||||
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
|
||||
|
||||
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||
return Embeddings.objects.filter(user=user)
|
||||
|
||||
for term in explicit_word_terms:
|
||||
if term.startswith("+"):
|
||||
q_filter_terms &= Q(raw__icontains=term[1:])
|
||||
elif term.startswith("-"):
|
||||
q_filter_terms &= ~Q(raw__icontains=term[1:])
|
||||
|
||||
q_file_filter_terms = Q()
|
||||
|
||||
if len(file_filters) > 0:
|
||||
for term in file_filters:
|
||||
q_file_filter_terms |= Q(file_path__regex=term)
|
||||
|
||||
q_filter_terms &= q_file_filter_terms
|
||||
|
||||
if len(date_filters) > 0:
|
||||
min_date, max_date = date_filters
|
||||
if min_date is not None:
|
||||
# Convert the min_date timestamp to yyyy-mm-dd format
|
||||
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
|
||||
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
|
||||
if max_date is not None:
|
||||
# Convert the max_date timestamp to yyyy-mm-dd format
|
||||
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
||||
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
||||
|
||||
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
|
||||
q_filter_terms,
|
||||
)
|
||||
if file_type_filter:
|
||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||
return relevant_embeddings
|
||||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||
):
|
||||
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
|
||||
distance=CosineDistance("embeddings", embeddings)
|
||||
)
|
||||
if file_type_filter:
|
||||
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||
relevant_embeddings = relevant_embeddings.order_by("distance")
|
||||
return relevant_embeddings[:max_results]
|
||||
|
||||
@staticmethod
|
||||
def get_unique_file_types(user: KhojUser):
|
||||
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
# Generated by Django 4.2.5 on 2023-09-27 17:52
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0002_googleuser"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Configuration",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="ConversationProcessorConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("conversation", models.JSONField()),
|
||||
("enable_offline_chat", models.BooleanField(default=False)),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GithubConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("pat_token", models.CharField(max_length=200)),
|
||||
("compressed_jsonl", models.CharField(max_length=300)),
|
||||
("embeddings_file", models.CharField(max_length=300)),
|
||||
(
|
||||
"config",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="uuid",
|
||||
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="NotionConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("token", models.CharField(max_length=200)),
|
||||
("compressed_jsonl", models.CharField(max_length=300)),
|
||||
("embeddings_file", models.CharField(max_length=300)),
|
||||
(
|
||||
"config",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GithubRepoConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("name", models.CharField(max_length=200)),
|
||||
("owner", models.CharField(max_length=200)),
|
||||
("branch", models.CharField(max_length=200)),
|
||||
(
|
||||
"github_config",
|
||||
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="configuration",
|
||||
name="user",
|
||||
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
]
|
10
src/database/migrations/0003_vector_extension.py
Normal file
10
src/database/migrations/0003_vector_extension.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from django.db import migrations
|
||||
from pgvector.django import VectorExtension
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0002_googleuser"),
|
||||
]
|
||||
|
||||
operations = [VectorExtension()]
|
|
@ -0,0 +1,193 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-11 22:24
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import pgvector.django
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0003_vector_extension"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="ConversationProcessorConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("conversation", models.JSONField()),
|
||||
("enable_offline_chat", models.BooleanField(default=False)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GithubConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("pat_token", models.CharField(max_length=200)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="uuid",
|
||||
field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="NotionConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("token", models.CharField(max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="LocalPlaintextConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("input_files", models.JSONField(default=list, null=True)),
|
||||
("input_filter", models.JSONField(default=list, null=True)),
|
||||
("index_heading_entries", models.BooleanField(default=False)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="LocalPdfConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("input_files", models.JSONField(default=list, null=True)),
|
||||
("input_filter", models.JSONField(default=list, null=True)),
|
||||
("index_heading_entries", models.BooleanField(default=False)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="LocalOrgConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("input_files", models.JSONField(default=list, null=True)),
|
||||
("input_filter", models.JSONField(default=list, null=True)),
|
||||
("index_heading_entries", models.BooleanField(default=False)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="LocalMarkdownConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("input_files", models.JSONField(default=list, null=True)),
|
||||
("input_filter", models.JSONField(default=list, null=True)),
|
||||
("index_heading_entries", models.BooleanField(default=False)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="GithubRepoConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("name", models.CharField(max_length=200)),
|
||||
("owner", models.CharField(max_length=200)),
|
||||
("branch", models.CharField(max_length=200)),
|
||||
(
|
||||
"github_config",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="githubrepoconfig",
|
||||
to="database.githubconfig",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="githubconfig",
|
||||
name="user",
|
||||
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Embeddings",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("embeddings", pgvector.django.VectorField(dimensions=384)),
|
||||
("raw", models.TextField()),
|
||||
("compiled", models.TextField()),
|
||||
("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)),
|
||||
(
|
||||
"file_type",
|
||||
models.CharField(
|
||||
choices=[
|
||||
("image", "Image"),
|
||||
("pdf", "Pdf"),
|
||||
("plaintext", "Plaintext"),
|
||||
("markdown", "Markdown"),
|
||||
("org", "Org"),
|
||||
("notion", "Notion"),
|
||||
("github", "Github"),
|
||||
("conversation", "Conversation"),
|
||||
],
|
||||
default="plaintext",
|
||||
max_length=30,
|
||||
),
|
||||
),
|
||||
("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||
("url", models.URLField(blank=True, default=None, max_length=400, null=True)),
|
||||
("hashed_value", models.CharField(max_length=100)),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
18
src/database/migrations/0005_embeddings_corpus_id.py
Normal file
18
src/database/migrations/0005_embeddings_corpus_id.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-13 02:39
|
||||
|
||||
from django.db import migrations, models
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0004_conversationprocessorconfig_githubconfig_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="embeddings",
|
||||
name="corpus_id",
|
||||
field=models.UUIDField(default=uuid.uuid4, editable=False),
|
||||
),
|
||||
]
|
33
src/database/migrations/0006_embeddingsdates.py
Normal file
33
src/database/migrations/0006_embeddingsdates.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-13 19:28
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0005_embeddings_corpus_id"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="EmbeddingsDates",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("date", models.DateField()),
|
||||
(
|
||||
"embeddings",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="embeddings_dates",
|
||||
to="database.embeddings",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")],
|
||||
},
|
||||
),
|
||||
]
|
|
@ -2,11 +2,25 @@ import uuid
|
|||
|
||||
from django.db import models
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from pgvector.django import VectorField
|
||||
|
||||
|
||||
class BaseModel(models.Model):
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class KhojUser(AbstractUser):
|
||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.uuid:
|
||||
self.uuid = uuid.uuid4()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class GoogleUser(models.Model):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
|
@ -23,31 +37,85 @@ class GoogleUser(models.Model):
|
|||
return self.name
|
||||
|
||||
|
||||
class Configuration(models.Model):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class NotionConfig(models.Model):
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
compressed_jsonl = models.CharField(max_length=300)
|
||||
embeddings_file = models.CharField(max_length=300)
|
||||
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class GithubConfig(models.Model):
|
||||
class GithubConfig(BaseModel):
|
||||
pat_token = models.CharField(max_length=200)
|
||||
compressed_jsonl = models.CharField(max_length=300)
|
||||
embeddings_file = models.CharField(max_length=300)
|
||||
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class GithubRepoConfig(models.Model):
|
||||
class GithubRepoConfig(BaseModel):
|
||||
name = models.CharField(max_length=200)
|
||||
owner = models.CharField(max_length=200)
|
||||
branch = models.CharField(max_length=200)
|
||||
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
|
||||
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
||||
|
||||
|
||||
class ConversationProcessorConfig(models.Model):
|
||||
class LocalOrgConfig(BaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
input_filter = models.JSONField(default=list, null=True)
|
||||
index_heading_entries = models.BooleanField(default=False)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class LocalMarkdownConfig(BaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
input_filter = models.JSONField(default=list, null=True)
|
||||
index_heading_entries = models.BooleanField(default=False)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class LocalPdfConfig(BaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
input_filter = models.JSONField(default=list, null=True)
|
||||
index_heading_entries = models.BooleanField(default=False)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class LocalPlaintextConfig(BaseModel):
|
||||
input_files = models.JSONField(default=list, null=True)
|
||||
input_filter = models.JSONField(default=list, null=True)
|
||||
index_heading_entries = models.BooleanField(default=False)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class ConversationProcessorConfig(BaseModel):
|
||||
conversation = models.JSONField()
|
||||
enable_offline_chat = models.BooleanField(default=False)
|
||||
|
||||
|
||||
class Embeddings(BaseModel):
|
||||
class EmbeddingsType(models.TextChoices):
|
||||
IMAGE = "image"
|
||||
PDF = "pdf"
|
||||
PLAINTEXT = "plaintext"
|
||||
MARKDOWN = "markdown"
|
||||
ORG = "org"
|
||||
NOTION = "notion"
|
||||
GITHUB = "github"
|
||||
CONVERSATION = "conversation"
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
embeddings = VectorField(dimensions=384)
|
||||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
|
||||
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
hashed_value = models.CharField(max_length=100)
|
||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
|
||||
|
||||
class EmbeddingsDates(BaseModel):
|
||||
date = models.DateField()
|
||||
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||
|
||||
class Meta:
|
||||
indexes = [
|
||||
models.Index(fields=["date"]),
|
||||
]
|
||||
|
|
3
src/database/tests.py
Normal file
3
src/database/tests.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
|
@ -30,6 +30,8 @@ from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
|||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser
|
||||
from database.adapters import get_all_users
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -48,14 +50,28 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
from database.models import KhojUser
|
||||
|
||||
self.khojuser_manager = KhojUser.objects
|
||||
self._initialize_default_user()
|
||||
super().__init__()
|
||||
|
||||
def _initialize_default_user(self):
|
||||
if not self.khojuser_manager.filter(username="default").exists():
|
||||
self.khojuser_manager.create_user(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",  # NOTE(review): hard-coded password for the auto-created "default" account; confirm this account cannot be used to log in on multi-user deployments
|
||||
)
|
||||
|
||||
async def authenticate(self, request):
|
||||
current_user = request.session.get("user")
|
||||
if current_user and current_user.get("email"):
|
||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
elif not state.anonymous_mode:  # NOTE(review): falls back to the "default" user when anonymous mode is OFF; looks inverted relative to the flag name, confirm intent
|
||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
|
||||
|
@ -78,7 +94,11 @@ def initialize_server(config: Optional[FullConfig]):
|
|||
|
||||
|
||||
def configure_server(
|
||||
config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
|
||||
config: FullConfig,
|
||||
regenerate: bool = False,
|
||||
search_type: Optional[SearchType] = None,
|
||||
init=False,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
# Update Config
|
||||
state.config = config
|
||||
|
@ -95,7 +115,7 @@ def configure_server(
|
|||
state.config_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
initialize_content(regenerate, search_type, init)
|
||||
initialize_content(regenerate, search_type, init, user)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
||||
raise e
|
||||
|
@ -103,7 +123,7 @@ def configure_server(
|
|||
state.config_lock.release()
|
||||
|
||||
|
||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
|
||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
try:
|
||||
|
@ -112,7 +132,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
|||
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
|
||||
else:
|
||||
logger.info("📬 Updating content index...")
|
||||
all_files = collect_files(state.config.content_type)
|
||||
all_files = collect_files(user=user)
|
||||
state.content_index = configure_content(
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
|
@ -120,6 +140,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
|||
state.search_models,
|
||||
regenerate,
|
||||
search_type,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to index content", exc_info=True)
|
||||
|
@ -152,9 +173,14 @@ if not state.demo:
|
|||
def update_search_index():
|
||||
try:
|
||||
logger.info("📬 Updating content index via Scheduler")
|
||||
all_files = collect_files(state.config.content_type)
|
||||
for user in get_all_users():
|
||||
all_files = collect_files(user=user)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=user
|
||||
)
|
||||
all_files = collect_files(user=None)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=None
|
||||
)
|
||||
logger.info("📪 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
|
@ -164,13 +190,9 @@ if not state.demo:
|
|||
def configure_search_types(config: FullConfig):
|
||||
# Extract core search types
|
||||
core_search_types = {e.name: e.value for e in SearchType}
|
||||
# Extract configured plugin search types
|
||||
plugin_search_types = {}
|
||||
if config.content_type and config.content_type.plugins:
|
||||
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
|
||||
|
||||
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
||||
return Enum("SearchType", merge_dicts(core_search_types, {}))
|
||||
|
||||
|
||||
def configure_processor(
|
||||
|
|
|
@ -10,12 +10,10 @@
|
|||
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
|
||||
<h3 class="card-title">
|
||||
Github
|
||||
{% if current_config.content_type.github %}
|
||||
{% if current_model_state.github == False %}
|
||||
<img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
{% if current_model_state.github == False %}
|
||||
<img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
</h3>
|
||||
</div>
|
||||
|
@ -24,7 +22,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/github">
|
||||
{% if current_config.content_type.github %}
|
||||
{% if current_model_state.github %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -32,7 +30,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.github %}
|
||||
{% if current_model_state.github %}
|
||||
<div id="clear-github" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('github')">
|
||||
Disable
|
||||
|
@ -45,12 +43,10 @@
|
|||
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
||||
<h3 class="card-title">
|
||||
Notion
|
||||
{% if current_config.content_type.notion %}
|
||||
{% if current_model_state.notion == False %}
|
||||
<img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
{% if current_model_state.notion == False %}
|
||||
<img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% endif %}
|
||||
</h3>
|
||||
</div>
|
||||
|
@ -59,7 +55,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/notion">
|
||||
{% if current_config.content_type.content %}
|
||||
{# Notion card: the Update/Setup label follows the notion state key, matching the other content-type cards #}
{% if current_model_state.notion %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -67,7 +63,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.notion %}
|
||||
{% if current_model_state.notion %}
|
||||
<div id="clear-notion" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('notion')">
|
||||
Disable
|
||||
|
@ -80,7 +76,7 @@
|
|||
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
|
||||
<h3 class="card-title">
|
||||
Markdown
|
||||
{% if current_config.content_type.markdown %}
|
||||
{% if current_model_state.markdown %}
|
||||
{% if current_model_state.markdown == False%}
|
||||
<img id="misconfigured-icon-markdown" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
|
@ -94,7 +90,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/markdown">
|
||||
{% if current_config.content_type.markdown %}
|
||||
{% if current_model_state.markdown %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -102,7 +98,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.markdown %}
|
||||
{% if current_model_state.markdown %}
|
||||
<div id="clear-markdown" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('markdown')">
|
||||
Disable
|
||||
|
@ -115,7 +111,7 @@
|
|||
<img class="card-icon" src="/static/assets/icons/org.svg" alt="org">
|
||||
<h3 class="card-title">
|
||||
Org
|
||||
{% if current_config.content_type.org %}
|
||||
{% if current_model_state.org %}
|
||||
{% if current_model_state.org == False %}
|
||||
<img id="misconfigured-icon-org" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
|
@ -129,7 +125,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/org">
|
||||
{% if current_config.content_type.org %}
|
||||
{% if current_model_state.org %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -137,7 +133,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.org %}
|
||||
{% if current_model_state.org %}
|
||||
<div id="clear-org" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('org')">
|
||||
Disable
|
||||
|
@ -150,7 +146,7 @@
|
|||
<img class="card-icon" src="/static/assets/icons/pdf.svg" alt="PDF">
|
||||
<h3 class="card-title">
|
||||
PDF
|
||||
{% if current_config.content_type.pdf %}
|
||||
{% if current_model_state.pdf %}
|
||||
{% if current_model_state.pdf == False %}
|
||||
<img id="misconfigured-icon-pdf" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
||||
{% else %}
|
||||
|
@ -164,7 +160,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/pdf">
|
||||
{% if current_config.content_type.pdf %}
|
||||
{% if current_model_state.pdf %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -172,7 +168,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.pdf %}
|
||||
{% if current_model_state.pdf %}
|
||||
<div id="clear-pdf" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('pdf')">
|
||||
Disable
|
||||
|
@ -185,7 +181,7 @@
|
|||
<img class="card-icon" src="/static/assets/icons/plaintext.svg" alt="Plaintext">
|
||||
<h3 class="card-title">
|
||||
Plaintext
|
||||
{% if current_config.content_type.plaintext %}
|
||||
{% if current_model_state.plaintext %}
|
||||
{% if current_model_state.plaintext == False %}
|
||||
<img id="misconfigured-icon-plaintext" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
||||
{% else %}
|
||||
|
@ -199,7 +195,7 @@
|
|||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/plaintext">
|
||||
{% if current_config.content_type.plaintext %}
|
||||
{% if current_model_state.plaintext %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@ -207,7 +203,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.content_type.plaintext %}
|
||||
{% if current_model_state.plaintext %}
|
||||
<div id="clear-plaintext" class="card-action-row">
|
||||
<button class="card-button" onclick="clearContentType('plaintext')">
|
||||
Disable
|
||||
|
|
|
@ -38,24 +38,6 @@
|
|||
{% endfor %}
|
||||
</div>
|
||||
<button type="button" id="add-repository-button">Add Repository</button>
|
||||
<table style="display: none;" >
|
||||
<tr>
|
||||
<td>
|
||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;"></div>
|
||||
<button id="submit" type="submit">Save</button>
|
||||
|
@ -107,8 +89,6 @@
|
|||
submit.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
|
||||
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
||||
const embeddings_file = document.getElementById("embeddings-file").value;
|
||||
const pat_token = document.getElementById("pat-token").value;
|
||||
|
||||
if (pat_token == "") {
|
||||
|
@ -154,8 +134,6 @@
|
|||
body: JSON.stringify({
|
||||
"pat_token": pat_token,
|
||||
"repos": repos,
|
||||
"compressed_jsonl": compressed_jsonl,
|
||||
"embeddings_file": embeddings_file,
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
|
|
|
@ -43,33 +43,6 @@
|
|||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<table style="display: none;" >
|
||||
<tr>
|
||||
<td>
|
||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="index-heading-entries">Index Heading Entries</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="index-heading-entries" name="index-heading-entries" value="{{ current_config['index_heading_entries'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;" ></div>
|
||||
<button id="submit" type="submit">Save</button>
|
||||
|
@ -155,9 +128,8 @@
|
|||
inputFilter = null;
|
||||
}
|
||||
|
||||
var compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
||||
var embeddings_file = document.getElementById("embeddings-file").value;
|
||||
var index_heading_entries = document.getElementById("index-heading-entries").value;
|
||||
// var index_heading_entries = document.getElementById("index-heading-entries").value;
|
||||
var index_heading_entries = true;
|
||||
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content_type/{{ content_type }}', {
|
||||
|
@ -169,8 +141,6 @@
|
|||
body: JSON.stringify({
|
||||
"input_files": inputFiles,
|
||||
"input_filter": inputFilter,
|
||||
"compressed_jsonl": compressed_jsonl,
|
||||
"embeddings_file": embeddings_file,
|
||||
"index_heading_entries": index_heading_entries
|
||||
})
|
||||
})
|
||||
|
|
|
@ -20,24 +20,6 @@
|
|||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<table style="display: none;" >
|
||||
<tr>
|
||||
<td>
|
||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;"></div>
|
||||
<button id="submit" type="submit">Save</button>
|
||||
|
@ -51,8 +33,6 @@
|
|||
submit.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
|
||||
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
||||
const embeddings_file = document.getElementById("embeddings-file").value;
|
||||
const token = document.getElementById("token").value;
|
||||
|
||||
if (token == "") {
|
||||
|
@ -70,8 +50,6 @@
|
|||
},
|
||||
body: JSON.stringify({
|
||||
"token": token,
|
||||
"compressed_jsonl": compressed_jsonl,
|
||||
"embeddings_file": embeddings_file,
|
||||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
|
|
|
@ -172,7 +172,7 @@
|
|||
url = createRequestUrl(query, type, results_count || 5, rerank);
|
||||
fetch(url, {
|
||||
headers: {
|
||||
"X-CSRFToken": csrfToken
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
|
@ -199,8 +199,8 @@
|
|||
fetch("/api/config/types")
|
||||
.then(response => response.json())
|
||||
.then(enabled_types => {
|
||||
// Show warning if no content types are enabled
|
||||
if (enabled_types.detail) {
|
||||
// Show warning if no content types are enabled, or just one ("all")
|
||||
if (enabled_types[0] === "all" && enabled_types.length === 1) {
|
||||
document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>";
|
||||
document.getElementById("query").setAttribute("disabled", "disabled");
|
||||
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");
|
||||
|
|
|
@ -24,10 +24,15 @@ from rich.logging import RichHandler
|
|||
from django.core.asgi import get_asgi_application
|
||||
from django.core.management import call_command
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||
from khoj.utils import state
|
||||
from khoj.utils.cli import cli
|
||||
# Initialize Django
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||
django.setup()
|
||||
|
||||
# Initialize Django Database
|
||||
call_command("migrate", "--noinput")
|
||||
|
||||
# Initialize Django Static Files
|
||||
call_command("collectstatic", "--noinput")
|
||||
|
||||
# Initialize Django
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||
|
@ -54,6 +59,11 @@ app.add_middleware(
|
|||
# Set Locale
|
||||
locale.setlocale(locale.LC_ALL, "")
|
||||
|
||||
# Internal Packages. We do this after setting up Django so that Django features are accessible to the app.
|
||||
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||
from khoj.utils import state
|
||||
from khoj.utils.cli import cli
|
||||
|
||||
# Setup Logger
|
||||
rich_handler = RichHandler(rich_tracebacks=True)
|
||||
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
|
||||
|
@ -95,6 +105,8 @@ def run():
|
|||
|
||||
# Mount Django and Static Files
|
||||
app.mount("/django", django_app, name="django")
|
||||
if not os.path.exists("static"):
|
||||
os.mkdir("static")
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
# Configure Middleware
|
||||
|
@ -111,6 +123,7 @@ def set_state(args):
|
|||
state.host = args.host
|
||||
state.port = args.port
|
||||
state.demo = args.demo
|
||||
state.anonymous_mode = args.anonymous_mode
|
||||
state.khoj_version = version("khoj-assistant")
|
||||
|
||||
|
57
src/khoj/processor/embeddings.py
Normal file
57
src/khoj/processor/embeddings.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
|
||||
class EmbeddingsModel:
|
||||
def __init__(self):
|
||||
self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
|
||||
encode_kwargs = {"normalize_embeddings": True}
|
||||
# encode_kwargs = {}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
device = torch.device("cuda:0")
|
||||
elif torch.backends.mps.is_available():
|
||||
# Use Apple M1 Metal Acceleration
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
model_kwargs = {"device": device}
|
||||
self.embeddings_model = HuggingFaceEmbeddings(
|
||||
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
def embed_query(self, query):
|
||||
return self.embeddings_model.embed_query(query)
|
||||
|
||||
def embed_documents(self, docs):
|
||||
return self.embeddings_model.embed_documents(docs)
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
def __init__(self):
|
||||
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
elif torch.backends.mps.is_available():
|
||||
# Use Apple M1 Metal Acceleration
|
||||
device = torch.device("mps")
|
||||
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
|
||||
|
||||
def predict(self, query, hits: List[SearchResponse]):
|
||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||
cross_scores = self.cross_encoder_model.predict(cross__inp)
|
||||
return cross_scores
|
|
@ -2,7 +2,7 @@
|
|||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
# External Packages
|
||||
import requests
|
||||
|
@ -12,18 +12,31 @@ from khoj.utils.helpers import timer
|
|||
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, GithubConfig, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GithubToJsonl(TextToJsonl):
|
||||
def __init__(self, config: GithubContentConfig):
|
||||
class GithubToJsonl(TextEmbeddings):
|
||||
def __init__(self, config: GithubConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
raw_repos = config.githubrepoconfig.all()
|
||||
repos = []
|
||||
for repo in raw_repos:
|
||||
repos.append(
|
||||
GithubRepoConfig(
|
||||
name=repo.name,
|
||||
owner=repo.owner,
|
||||
branch=repo.branch,
|
||||
)
|
||||
)
|
||||
self.config = GithubContentConfig(
|
||||
pat_token=config.pat_token,
|
||||
repos=repos,
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
|
||||
|
||||
|
@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl):
|
|||
else:
|
||||
return
|
||||
|
||||
def process(self, previous_entries=[], files=None, full_corpus=True):
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
if self.config.pat_token is None or self.config.pat_token == "":
|
||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||
|
@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl):
|
|||
for repo in self.config.repos:
|
||||
current_entries += self.process_repo(repo)
|
||||
|
||||
return self.update_entries_with_ids(current_entries, previous_entries)
|
||||
return self.update_entries_with_ids(current_entries, user=user)
|
||||
|
||||
def process_repo(self, repo: GithubRepoConfig):
|
||||
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
||||
|
@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl):
|
|||
current_entries += issue_entries
|
||||
|
||||
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
||||
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
|
||||
return current_entries
|
||||
|
||||
def update_entries_with_ids(self, current_entries, previous_entries):
|
||||
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
|
||||
)
|
||||
|
||||
with timer("Write github entries to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
def get_files(self, repo_url: str, repo: GithubRepoConfig):
|
||||
# Get the contents of the repository
|
||||
|
|
|
@ -1,91 +0,0 @@
|
|||
# Standard Packages
|
||||
import glob
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.helpers import get_absolute_path, timer
|
||||
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonlToJsonl(TextToJsonl):
    """Processor for content that is already serialized as JSONL files.

    Reads the JSONL files named in ``self.config``, turns each record into an
    ``Entry``, splits oversized entries, reconciles them with the previous
    entries and hands the result to ``compress_jsonl_data``.
    """

    # Define Functions
    def process(self, previous_entries=None, files: dict[str, str] = None, full_corpus: bool = True):
        """Build the current set of entries from the configured JSONL files.

        Args:
            previous_entries: Entries from an earlier run to compare against.
                ``None`` (the default) means there are no previous entries.
            files: Not read by this processor; kept so the signature matches
                the other processors.
            full_corpus: Not read by this processor; kept for the same reason.

        Returns:
            The value returned by ``TextToJsonl.mark_entries_for_update``. The
            second item of each element is the entry itself.
        """
        # Use a fresh list per call instead of a shared mutable default argument
        if previous_entries is None:
            previous_entries = []

        # Extract required fields from config
        input_jsonl_files, input_jsonl_filter, output_file = (
            self.config.input_files,
            self.config.input_filter,
            self.config.compressed_jsonl,
        )

        # Get Jsonl Input Files to Process
        all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter)

        # Extract Entries from specified jsonl files
        with timer("Parse entries from jsonl files", logger):
            input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files)
            current_entries = list(map(Entry.from_dict, input_jsons))

        # Split entries by max tokens supported by model
        with timer("Split entries by max token size supported by model", logger):
            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)

        # Identify, mark and merge any new entries with previous entries
        with timer("Identify new or updated entries", logger):
            entries_with_ids = TextToJsonl.mark_entries_for_update(
                current_entries, previous_entries, key="compiled", logger=logger
            )

        with timer("Write entries to JSONL file", logger):
            # Keep only the entry from each (id, entry) pair
            entries = [entry_with_id[1] for entry_with_id in entries_with_ids]
            jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)

            # Compress JSONL formatted Data
            compress_jsonl_data(jsonl_data, output_file)

        return entries_with_ids

    @staticmethod
    def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None):
        """Get all jsonl files to process.

        Args:
            jsonl_files: Optional iterable of explicit file paths.
            jsonl_file_filters: Optional iterable of glob patterns, expanded
                recursively.

        Returns:
            Sorted list of the de-duplicated absolute paths from both inputs.
        """
        absolute_jsonl_files, filtered_jsonl_files = set(), set()
        if jsonl_files:
            absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files}
        if jsonl_file_filters:
            filtered_jsonl_files = {
                filtered_file
                for jsonl_file_filter in jsonl_file_filters
                for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
            }

        all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files)

        files_with_non_jsonl_extensions = {
            jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl")
        }
        if any(files_with_non_jsonl_extensions):
            # Route the warning through the module logger like the rest of this class
            logger.warning(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}")

        logger.debug(f"Processing files: {all_jsonl_files}")

        return all_jsonl_files

    @staticmethod
    def extract_jsonl_entries(jsonl_files):
        """Extract entries from the specified jsonl files, in the order given."""
        entries = []
        for jsonl_file in jsonl_files:
            entries.extend(load_jsonl(Path(jsonl_file)))
        return entries

    @staticmethod
    def convert_entries_to_jsonl(entries: List[Entry]):
        """Convert each entry to JSON and collate as JSONL (one entry per line)."""
        return "".join([f"{entry.to_json()}\n" for entry in entries])
|
|
@ -3,29 +3,28 @@ import logging
|
|||
import re
|
||||
import urllib3
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Tuple, List
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownToJsonl(TextToJsonl):
|
||||
def __init__(self, config: TextContentConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
class MarkdownToJsonl(TextEmbeddings):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[], files=None, full_corpus: bool = True):
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
output_file = self.config.compressed_jsonl
|
||||
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
|
@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.MARKDOWN,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
)
|
||||
|
||||
with timer("Write markdown entries to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
# External Packages
|
||||
import requests
|
||||
|
@ -7,9 +8,9 @@ import requests
|
|||
# Internal Packages
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser, NotionConfig
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
@ -49,10 +50,12 @@ class NotionBlockType(Enum):
|
|||
CALLOUT = "callout"
|
||||
|
||||
|
||||
class NotionToJsonl(TextToJsonl):
|
||||
def __init__(self, config: NotionContentConfig):
|
||||
class NotionToJsonl(TextEmbeddings):
|
||||
def __init__(self, config: NotionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.config = NotionContentConfig(
|
||||
token=config.token,
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
|
||||
self.unsupported_block_types = [
|
||||
|
@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl):
|
|||
|
||||
self.body_params = {"page_size": 100}
|
||||
|
||||
def process(self, previous_entries=[], files=None, full_corpus=True):
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
current_entries = []
|
||||
|
||||
# Get all pages
|
||||
|
@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl):
|
|||
page_entries = self.process_page(p_or_d)
|
||||
current_entries.extend(page_entries)
|
||||
|
||||
return self.update_entries_with_ids(current_entries, previous_entries)
|
||||
return self.update_entries_with_ids(current_entries, user)
|
||||
|
||||
def process_page(self, page):
|
||||
page_id = page["id"]
|
||||
|
@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl):
|
|||
title = None
|
||||
return title, content
|
||||
|
||||
def update_entries_with_ids(self, current_entries, previous_entries):
|
||||
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
|
||||
)
|
||||
|
||||
with timer("Write Notion entries to JSONL file", logger):
|
||||
# Process Each Entry from all Notion entries
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
|
|
@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple
|
|||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode import orgnode
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils import state
|
||||
from database.models import Embeddings, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrgToJsonl(TextToJsonl):
|
||||
def __init__(self, config: TextContentConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
class OrgToJsonl(TextEmbeddings):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
||||
) -> List[Tuple[int, Entry]]:
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
output_file = self.config.compressed_jsonl
|
||||
index_heading_entries = self.config.index_heading_entries
|
||||
index_heading_entries = True
|
||||
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
|
@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.ORG,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
with timer("Write org entries to JSONL file", logger):
|
||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_org_entries(org_files: dict[str, str]):
|
||||
|
|
|
@ -1,28 +1,31 @@
|
|||
# Standard Packages
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
import base64
|
||||
|
||||
# External Packages
|
||||
from langchain.document_loaders import PyMuPDFLoader
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from database.models import Embeddings, KhojUser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True):
|
||||
# Extract required fields from config
|
||||
output_file = self.config.compressed_jsonl
|
||||
class PdfToJsonl(TextEmbeddings):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
|
@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.PDF,
|
||||
"compiled",
|
||||
logger,
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
)
|
||||
|
||||
with timer("Write PDF entries to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files):
|
||||
|
@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl):
|
|||
entry_to_location_map = []
|
||||
for pdf_file in pdf_files:
|
||||
try:
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
|
||||
tmp_file = f"tmp_pdf_file.pdf"
|
||||
with open(f"{tmp_file}", "wb") as f:
|
||||
bytes = pdf_files[pdf_file]
|
||||
|
|
|
@ -4,22 +4,23 @@ from pathlib import Path
|
|||
from typing import List, Tuple
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.jsonl import compress_jsonl_data
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
from database.models import Embeddings, KhojUser, LocalPlaintextConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlaintextToJsonl(TextToJsonl):
|
||||
class PlaintextToJsonl(TextEmbeddings):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Define Functions
|
||||
def process(
|
||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
||||
) -> List[Tuple[int, Entry]]:
|
||||
output_file = self.config.compressed_jsonl
|
||||
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
|
@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl):
|
|||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
||||
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||
current_entries,
|
||||
Embeddings.EmbeddingsType.PLAINTEXT,
|
||||
key="compiled",
|
||||
logger=logger,
|
||||
deletion_filenames=deletion_file_names,
|
||||
user=user,
|
||||
regenerate=regenerate,
|
||||
)
|
||||
|
||||
with timer("Write entries to JSONL file", logger):
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
compress_jsonl_data(plaintext_data, output_file)
|
||||
|
||||
return entries_with_ids
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||
|
|
|
@ -2,24 +2,33 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Callable, List, Tuple, Set
|
||||
import uuid
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, List, Tuple, Set, Any
|
||||
from khoj.utils.helpers import timer
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.rawconfig import Entry, TextConfigBase
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from database.models import KhojUser, Embeddings, EmbeddingsDates
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToJsonl(ABC):
|
||||
def __init__(self, config: TextConfigBase):
|
||||
class TextEmbeddings(ABC):
|
||||
def __init__(self, config: Any = None):
|
||||
self.embeddings_model = EmbeddingsModel()
|
||||
self.config = config
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
@abstractmethod
|
||||
def process(
|
||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
||||
) -> List[Tuple[int, Entry]]:
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
|
@ -38,6 +47,7 @@ class TextToJsonl(ABC):
|
|||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||
corpus_id = uuid.uuid4()
|
||||
|
||||
# Split entry into chunks of max tokens
|
||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||
|
@ -57,11 +67,103 @@ class TextToJsonl(ABC):
|
|||
raw=entry.raw,
|
||||
heading=entry.heading,
|
||||
file=entry.file,
|
||||
corpus_id=corpus_id,
|
||||
)
|
||||
)
|
||||
|
||||
return chunked_entries
|
||||
|
||||
def update_embeddings(
    self,
    current_entries: List[Entry],
    file_type: str,
    key="compiled",
    logger: logging.Logger = None,
    deletion_filenames: Set[str] = None,
    user: KhojUser = None,
    regenerate: bool = False,
):
    """Sync the stored embeddings of ``user`` with ``current_entries``.

    Entries are identified by the hash of their ``key`` attribute. For every
    file seen in ``current_entries``, entries whose hash is not stored yet are
    embedded and inserted, and stored hashes that no longer occur in that file
    are deleted.

    Args:
        current_entries: Entries extracted from the content being indexed.
        file_type: Value stored in, and filtered on, ``Embeddings.file_type``.
        key: Name of the entry attribute that is hashed and embedded.
        logger: Logger passed to the timers and used for info messages. When
            omitted, the info messages go to this module's logger.
        deletion_filenames: File paths whose stored embeddings are all deleted.
        user: Owner of the embeddings.
        regenerate: When true, delete every stored embedding of ``file_type``
            for ``user`` before adding new ones.

    Returns:
        Tuple ``(num_new_embeddings, num_deleted_embeddings)``.
    """
    # The `logger` parameter shadows the module logger and may be None;
    # resolve a usable logger so the info calls below cannot raise AttributeError.
    # The timers still receive the caller's value unchanged.
    info_logger = logger if logger is not None else logging.getLogger(__name__)

    # Build the hashing callable once instead of once per entry
    hash_entry = TextEmbeddings.hash_func(key)

    with timer("Construct current entry hashes", logger):
        hashes_by_file = dict[str, set[str]]()
        current_entry_hashes = list(map(hash_entry, current_entries))
        hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
        # Reuse the hashes computed above; the lookups below already depend on
        # the hash of an entry being stable between calls.
        for entry, entry_hash in tqdm(
            zip(current_entries, current_entry_hashes), total=len(current_entries), desc="Hashing Entries"
        ):
            hashes_by_file.setdefault(entry.file, set()).add(entry_hash)

    num_deleted_embeddings = 0
    with timer("Preparing dataset for regeneration", logger):
        if regenerate:
            info_logger.info(f"Deleting all embeddings for file type {file_type}")
            num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)

    num_new_embeddings = 0
    with timer("Identify hashes for adding new entries", logger):
        for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
            hashes_for_file = hashes_by_file[file]
            existing_entries = Embeddings.objects.filter(
                user=user, hashed_value__in=hashes_for_file, file_type=file_type
            )
            existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
            # A list fixes one iteration order for both the embedding input and the zip below
            hashes_to_process = list(hashes_for_file - existing_entry_hashes)
            if not hashes_to_process:
                # Nothing new in this file: skip the embedding model and the empty inserts
                continue

            entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
            data_to_embed = [getattr(entry, key) for entry in entries_to_process]
            embeddings = self.embeddings_model.embed_documents(data_to_embed)

            with timer("Update the database with new vector embeddings", logger):
                embeddings_to_create = []
                for hashed_val, embedding in zip(hashes_to_process, embeddings):
                    entry = hash_to_current_entries[hashed_val]
                    embeddings_to_create.append(
                        Embeddings(
                            user=user,
                            embeddings=embedding,
                            raw=entry.raw,
                            compiled=entry.compiled,
                            heading=entry.heading,
                            file_path=entry.file,
                            file_type=file_type,
                            hashed_value=hashed_val,
                            corpus_id=entry.corpus_id,
                        )
                    )
                new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
                num_new_embeddings += len(new_embeddings)

                dates_to_create = []
                with timer("Create new date associations for new embeddings", logger):
                    for embedding in new_embeddings:
                        dates = self.date_filter.extract_dates(embedding.raw)
                        for date in dates:
                            dates_to_create.append(
                                EmbeddingsDates(
                                    date=date,
                                    embeddings=embedding,
                                )
                            )
                    new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
                    if len(new_dates) > 0:
                        info_logger.info(f"Created {len(new_dates)} new date entries")

    with timer("Identify hashes for removed entries", logger):
        for file in hashes_by_file:
            existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
            # Stored hashes that the current version of the file no longer produces
            to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
            num_deleted_embeddings += len(to_delete_entry_hashes)
            EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))

    with timer("Identify hashes for deleting entries", logger):
        if deletion_filenames is not None:
            for file_path in deletion_filenames:
                deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
                num_deleted_embeddings += deleted_count

    return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def mark_entries_for_update(
|
||||
current_entries: List[Entry],
|
||||
|
@ -72,11 +174,11 @@ class TextToJsonl(ABC):
|
|||
):
|
||||
# Hash all current and previous entries to identify new entries
|
||||
with timer("Hash previous, current entries", logger):
|
||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
||||
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
|
||||
previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries))
|
||||
if deletion_filenames is not None:
|
||||
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
|
||||
deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries))
|
||||
deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries))
|
||||
else:
|
||||
deletion_entry_hashes = []
|
||||
|
||||
|
|
|
@ -2,14 +2,15 @@
|
|||
import concurrent.futures
|
||||
import math
|
||||
import time
|
||||
import yaml
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional, Union, Any
|
||||
import asyncio
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from sentence_transformers import util
|
||||
from fastapi import APIRouter, HTTPException, Header, Request, Depends
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_server
|
||||
|
@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter
|
|||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
FullConfig,
|
||||
ProcessorConfig,
|
||||
SearchConfig,
|
||||
|
@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions
|
|||
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
|
||||
from fastapi.requests import Request
|
||||
|
||||
from database import adapters
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def map_config_to_object(content_type: str):
    """Return the local config model class for a content type name.

    Recognized names are "org", "markdown", "pdf" and "plaintext"; any other
    value yields None.
    """
    known_models = (
        ("org", LocalOrgConfig),
        ("markdown", LocalMarkdownConfig),
        ("pdf", LocalPdfConfig),
        ("plaintext", LocalPlaintextConfig),
    )
    for type_name, config_model in known_models:
        if content_type == type_name:
            return config_model
    return None
|
||||
|
||||
|
||||
async def map_config_to_db(config: FullConfig, user: KhojUser):
    """Store the content settings found in `config` as database rows for `user`.

    For each of the org, markdown, pdf and plaintext sections that is set, the
    user's existing rows of the matching model are deleted and one new row is
    created from the section. The github and notion sections, when set, are
    handed to their adapter functions.
    """
    if not config.content_type:
        return

    content = config.content_type

    # Section name on the config paired with the model that stores it,
    # processed in this order.
    local_sections = (
        ("org", LocalOrgConfig),
        ("markdown", LocalMarkdownConfig),
        ("pdf", LocalPdfConfig),
        ("plaintext", LocalPlaintextConfig),
    )
    for section_name, config_model in local_sections:
        section = getattr(content, section_name)
        if not section:
            continue
        # Replace, rather than update, whatever the user had stored before
        await config_model.objects.filter(user=user).adelete()
        await config_model.objects.acreate(
            input_files=section.input_files,
            input_filter=section.input_filter,
            index_heading_entries=section.index_heading_entries,
            user=user,
        )

    github_section = content.github
    if github_section:
        await adapters.set_user_github_config(
            user=user,
            pat_token=github_section.pat_token,
            repos=github_section.repos,
        )

    notion_section = content.notion
    if notion_section:
        await adapters.set_notion_config(
            user=user,
            token=notion_section.token,
        )
|
||||
|
||||
|
||||
# If it's a demo instance, prevent updating any of the configuration.
|
||||
if not state.demo:
|
||||
|
||||
|
@ -64,7 +127,10 @@ if not state.demo:
|
|||
state.processor_config = configure_processor(state.config.processor)
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
def get_config_data():
|
||||
def get_config_data(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
|
||||
|
||||
return state.config
|
||||
|
||||
@api.post("/config/data")
|
||||
|
@ -73,20 +139,19 @@ if not state.demo:
|
|||
updated_config: FullConfig,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
state.config = updated_config
|
||||
with open(state.config_file, "w") as outfile:
|
||||
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
|
||||
outfile.close()
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
await map_config_to_db(updated_config, user)
|
||||
|
||||
configuration_update_metadata = dict()
|
||||
configuration_update_metadata = {}
|
||||
|
||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||
|
||||
if state.config.content_type is not None:
|
||||
configuration_update_metadata["github"] = state.config.content_type.github is not None
|
||||
configuration_update_metadata["notion"] = state.config.content_type.notion is not None
|
||||
configuration_update_metadata["org"] = state.config.content_type.org is not None
|
||||
configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None
|
||||
configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None
|
||||
configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None
|
||||
configuration_update_metadata["github"] = "github" in enabled_content
|
||||
configuration_update_metadata["notion"] = "notion" in enabled_content
|
||||
configuration_update_metadata["org"] = "org" in enabled_content
|
||||
configuration_update_metadata["pdf"] = "pdf" in enabled_content
|
||||
configuration_update_metadata["markdown"] = "markdown" in enabled_content
|
||||
|
||||
if state.config.processor is not None:
|
||||
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
|
||||
|
@ -101,6 +166,7 @@ if not state.demo:
|
|||
return state.config
|
||||
|
||||
@api.post("/config/data/content_type/github", status_code=200)
|
||||
@requires("authenticated")
|
||||
async def set_content_config_github_data(
|
||||
request: Request,
|
||||
updated_config: Union[GithubContentConfig, None],
|
||||
|
@ -108,10 +174,13 @@ if not state.demo:
|
|||
):
|
||||
_initialize_config()
|
||||
|
||||
if not state.config.content_type:
|
||||
state.config.content_type = ContentConfig(**{"github": updated_config})
|
||||
else:
|
||||
state.config.content_type.github = updated_config
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
await adapters.set_user_github_config(
|
||||
user=user,
|
||||
pat_token=updated_config.pat_token,
|
||||
repos=updated_config.repos,
|
||||
)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -121,11 +190,7 @@ if not state.demo:
|
|||
metadata={"content_type": "github"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/content_type/notion", status_code=200)
|
||||
async def set_content_config_notion_data(
|
||||
|
@ -135,10 +200,12 @@ if not state.demo:
|
|||
):
|
||||
_initialize_config()
|
||||
|
||||
if not state.config.content_type:
|
||||
state.config.content_type = ContentConfig(**{"notion": updated_config})
|
||||
else:
|
||||
state.config.content_type.notion = updated_config
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=updated_config.token,
|
||||
)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -148,11 +215,7 @@ if not state.demo:
|
|||
metadata={"content_type": "notion"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
||||
async def remove_content_config_data(
|
||||
|
@ -160,8 +223,7 @@ if not state.demo:
|
|||
content_type: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
if not state.config or not state.config.content_type:
|
||||
return {"status": "ok"}
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -171,31 +233,13 @@ if not state.demo:
|
|||
metadata={"content_type": content_type},
|
||||
)
|
||||
|
||||
if state.config.content_type:
|
||||
state.config.content_type[content_type] = None
|
||||
content_object = map_config_to_object(content_type)
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
|
||||
|
||||
if content_type == "github":
|
||||
state.content_index.github = None
|
||||
elif content_type == "notion":
|
||||
state.content_index.notion = None
|
||||
elif content_type == "plugins":
|
||||
state.content_index.plugins = None
|
||||
elif content_type == "pdf":
|
||||
state.content_index.pdf = None
|
||||
elif content_type == "markdown":
|
||||
state.content_index.markdown = None
|
||||
elif content_type == "org":
|
||||
state.content_index.org = None
|
||||
elif content_type == "plaintext":
|
||||
state.content_index.plaintext = None
|
||||
else:
|
||||
logger.warning(f"Request to delete unknown content type: {content_type} via API")
|
||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||
async def remove_processor_conversation_config_data(
|
||||
|
@ -228,6 +272,7 @@ if not state.demo:
|
|||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||
# @requires("authenticated")
|
||||
async def set_content_config_data(
|
||||
request: Request,
|
||||
content_type: str,
|
||||
|
@ -236,10 +281,10 @@ if not state.demo:
|
|||
):
|
||||
_initialize_config()
|
||||
|
||||
if not state.config.content_type:
|
||||
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
||||
else:
|
||||
state.config.content_type[content_type] = updated_config
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
content_object = map_config_to_object(content_type)
|
||||
await adapters.set_text_content_config(user, content_object, updated_config)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -249,11 +294,7 @@ if not state.demo:
|
|||
metadata={"content_type": content_type},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/processor/conversation/openai", status_code=200)
|
||||
async def set_processor_openai_config_data(
|
||||
|
@ -337,24 +378,23 @@ def get_default_config_data():
|
|||
|
||||
|
||||
@api.get("/config/types", response_model=List[str])
|
||||
def get_config_types():
|
||||
"""Get configured content types"""
|
||||
if state.config is None or state.config.content_type is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Content types not configured. Configure at least one content type on server and restart it.",
|
||||
)
|
||||
def get_config_types(
|
||||
request: Request,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
|
||||
|
||||
configured_content_types = list(enabled_file_types)
|
||||
|
||||
if state.config and state.config.content_type:
|
||||
for ctype in state.config.content_type.dict(exclude_none=True):
|
||||
configured_content_types.append(ctype)
|
||||
|
||||
configured_content_types = state.config.content_type.dict(exclude_none=True)
|
||||
return [
|
||||
search_type.value
|
||||
for search_type in SearchType
|
||||
if (
|
||||
search_type.value in configured_content_types
|
||||
and getattr(state.content_index, search_type.value) is not None
|
||||
)
|
||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
||||
or search_type == SearchType.All
|
||||
if (search_type.value in configured_content_types) or search_type == SearchType.All
|
||||
]
|
||||
|
||||
|
||||
|
@ -372,6 +412,7 @@ async def search(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
start_time = time.time()
|
||||
|
||||
# Run validation checks
|
||||
|
@ -390,10 +431,11 @@ async def search(
|
|||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
if query_cache_key in state.query_cache:
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[query_cache_key]
|
||||
if user:
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
if query_cache_key in state.query_cache[user.uuid]:
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[user.uuid][query_cache_key]
|
||||
|
||||
# Encode query with filter terms removed
|
||||
defiltered_query = user_query
|
||||
|
@ -407,84 +449,31 @@ async def search(
|
|||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
encoded_asymmetric_query = util.normalize_embeddings(
|
||||
text_search_models[0].bi_encoder.encode(
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
)
|
||||
)
|
||||
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
|
||||
# query org-mode notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.org,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
(t == SearchType.Markdown or t == SearchType.All)
|
||||
and state.content_index.markdown
|
||||
and state.search_models.text_search
|
||||
):
|
||||
if t in [
|
||||
SearchType.All,
|
||||
SearchType.Org,
|
||||
SearchType.Markdown,
|
||||
SearchType.Github,
|
||||
SearchType.Notion,
|
||||
SearchType.Plaintext,
|
||||
]:
|
||||
# query markdown notes
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.markdown,
|
||||
t,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
(t == SearchType.Github or t == SearchType.All)
|
||||
and state.content_index.github
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query github issues
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.github,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
|
||||
# query pdf files
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.pdf,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||
elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||
# query images
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
|
@ -497,70 +486,6 @@ async def search(
|
|||
)
|
||||
]
|
||||
|
||||
if (
|
||||
(t == SearchType.All or t in SearchType)
|
||||
and state.content_index.plugins
|
||||
and state.search_models.plugin_search
|
||||
):
|
||||
# query specified plugin type
|
||||
# Get plugin content, search model for specified search type, or the first one if none specified
|
||||
plugin_search = state.search_models.plugin_search.get(t.value) or next(
|
||||
iter(state.search_models.plugin_search.values())
|
||||
)
|
||||
plugin_content = state.content_index.plugins.get(t.value) or next(
|
||||
iter(state.content_index.plugins.values())
|
||||
)
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
plugin_search,
|
||||
plugin_content,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
(t == SearchType.Notion or t == SearchType.All)
|
||||
and state.content_index.notion
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query notion pages
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.notion,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
(t == SearchType.Plaintext or t == SearchType.All)
|
||||
and state.content_index.plaintext
|
||||
and state.search_models.text_search
|
||||
):
|
||||
# query plaintext files
|
||||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user_query,
|
||||
state.search_models.text_search,
|
||||
state.content_index.plaintext,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
dedupe=dedupe or True,
|
||||
)
|
||||
]
|
||||
|
||||
# Query across each requested content types in parallel
|
||||
with timer("Query took", logger):
|
||||
for search_future in concurrent.futures.as_completed(search_futures):
|
||||
|
@ -576,15 +501,19 @@ async def search(
|
|||
count=results_count,
|
||||
)
|
||||
else:
|
||||
hits, entries = await search_future.result()
|
||||
hits = await search_future.result()
|
||||
# Collate results
|
||||
results += text_search.collate_results(hits, entries, results_count)
|
||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
# Sort results across all content types and take top results
|
||||
results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]
|
||||
if r:
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
||||
else:
|
||||
# Sort results across all content types and take top results
|
||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
||||
|
||||
# Cache results
|
||||
state.query_cache[query_cache_key] = results
|
||||
if user:
|
||||
state.query_cache[user.uuid][query_cache_key] = results
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -596,8 +525,6 @@ async def search(
|
|||
host=host,
|
||||
)
|
||||
|
||||
state.previous_query = user_query
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||
|
||||
|
@ -614,12 +541,13 @@ def update(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
if not state.config:
|
||||
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||
logger.warning(error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
try:
|
||||
configure_server(state.config, regenerate=force, search_type=t)
|
||||
configure_server(state.config, regenerate=force, search_type=t, user=user)
|
||||
except Exception as e:
|
||||
error_msg = f"🚨 Failed to update server via API: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
@ -774,6 +702,7 @@ async def extract_references_and_questions(
|
|||
n: int,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
|
@ -781,7 +710,7 @@ async def extract_references_and_questions(
|
|||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[str] = []
|
||||
|
||||
if state.content_index is None:
|
||||
if not EmbeddingsAdapters.user_has_embeddings(user=user):
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
|
|||
from khoj.processor.conversation.openai.gpt import converse
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
|
||||
from database.models import KhojUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -40,11 +41,13 @@ def update_telemetry_state(
|
|||
host: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
"server_id": str(user.uuid) if user else None,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Standard Packages
|
||||
import logging
|
||||
from typing import Optional, Union, Dict
|
||||
import asyncio
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
|
||||
|
@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import state, constants
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.utils.rawconfig import ContentConfig, TextContentConfig
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.constants import default_config
|
||||
from khoj.utils.helpers import LRU, get_file_type
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
FullConfig,
|
||||
SearchConfig,
|
||||
)
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.utils.config import (
|
||||
ContentIndex,
|
||||
SearchModels,
|
||||
)
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
GithubConfig,
|
||||
NotionConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -68,14 +68,14 @@ async def update(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
if x_api_key != "secret":
|
||||
raise HTTPException(status_code=401, detail="Invalid API Key")
|
||||
state.config_lock.acquire()
|
||||
try:
|
||||
logger.info(f"📬 Updating content index via API call by {client} client")
|
||||
org_files: Dict[str, str] = {}
|
||||
markdown_files: Dict[str, str] = {}
|
||||
pdf_files: Dict[str, str] = {}
|
||||
pdf_files: Dict[str, bytes] = {}
|
||||
plaintext_files: Dict[str, str] = {}
|
||||
|
||||
for file in files:
|
||||
|
@ -86,7 +86,7 @@ async def update(
|
|||
elif file_type == "markdown":
|
||||
dict_to_update = markdown_files
|
||||
elif file_type == "pdf":
|
||||
dict_to_update = pdf_files
|
||||
dict_to_update = pdf_files # type: ignore
|
||||
elif file_type == "plaintext":
|
||||
dict_to_update = plaintext_files
|
||||
|
||||
|
@ -120,30 +120,31 @@ async def update(
|
|||
github=None,
|
||||
notion=None,
|
||||
plaintext=None,
|
||||
plugins=None,
|
||||
)
|
||||
state.config.content_type = default_content_config
|
||||
save_config_to_file_updated_state()
|
||||
configure_search(state.search_models, state.config.search_type)
|
||||
|
||||
# Extract required fields from config
|
||||
state.content_index = configure_content(
|
||||
loop = asyncio.get_event_loop()
|
||||
state.content_index = await loop.run_in_executor(
|
||||
None,
|
||||
configure_content,
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
indexer_input.dict(),
|
||||
state.search_models,
|
||||
regenerate=force,
|
||||
t=t,
|
||||
full_corpus=False,
|
||||
force,
|
||||
t,
|
||||
False,
|
||||
user,
|
||||
)
|
||||
|
||||
logger.info(f"Finished processing batch indexing request")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
|
|||
if search_models is None:
|
||||
search_models = SearchModels()
|
||||
|
||||
# Initialize Search Models
|
||||
if search_config.asymmetric:
|
||||
logger.info("🔍 📜 Setting up text search model")
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
|
||||
if search_config.image:
|
||||
logger.info("🔍 🌄 Setting up image search model")
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
@ -187,16 +183,9 @@ def configure_content(
|
|||
regenerate: bool = False,
|
||||
t: Optional[Union[state.SearchType, str]] = None,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
def has_valid_text_config(config: TextContentConfig):
|
||||
return config.input_files or config.input_filter
|
||||
|
||||
# Run Validation Checks
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
content_index = ContentIndex()
|
||||
|
||||
if t in [type.value for type in state.SearchType]:
|
||||
t = state.SearchType(t).value
|
||||
|
@ -209,59 +198,30 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Org.value)
|
||||
and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"])
|
||||
and search_models.text_search
|
||||
):
|
||||
if content_config.org == None:
|
||||
logger.info("🦄 No configuration for orgmode notes. Using default configuration.")
|
||||
default_configuration = default_config["content-type"]["org"] # type: ignore
|
||||
content_config.org = TextContentConfig(
|
||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
||||
embeddings_file=default_configuration["embeddings-file"],
|
||||
)
|
||||
|
||||
if (t == None or t == state.SearchType.Org.value) and files["org"]:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
files.get("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
# Initialize Markdown Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Markdown.value)
|
||||
and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"])
|
||||
and search_models.text_search
|
||||
and files["markdown"]
|
||||
):
|
||||
if content_config.markdown == None:
|
||||
logger.info("💎 No configuration for markdown notes. Using default configuration.")
|
||||
default_configuration = default_config["content-type"]["markdown"] # type: ignore
|
||||
content_config.markdown = TextContentConfig(
|
||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
||||
embeddings_file=default_configuration["embeddings-file"],
|
||||
)
|
||||
|
||||
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
content_index.markdown = text_search.setup(
|
||||
text_search.setup(
|
||||
MarkdownToJsonl,
|
||||
files.get("markdown"),
|
||||
content_config.markdown,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -269,30 +229,15 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize PDF Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Pdf.value)
|
||||
and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"])
|
||||
and search_models.text_search
|
||||
and files["pdf"]
|
||||
):
|
||||
if content_config.pdf == None:
|
||||
logger.info("🖨️ No configuration for pdf notes. Using default configuration.")
|
||||
default_configuration = default_config["content-type"]["pdf"] # type: ignore
|
||||
content_config.pdf = TextContentConfig(
|
||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
||||
embeddings_file=default_configuration["embeddings-file"],
|
||||
)
|
||||
|
||||
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
content_index.pdf = text_search.setup(
|
||||
text_search.setup(
|
||||
PdfToJsonl,
|
||||
files.get("pdf"),
|
||||
content_config.pdf,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -300,30 +245,15 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Plaintext Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Plaintext.value)
|
||||
and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"])
|
||||
and search_models.text_search
|
||||
and files["plaintext"]
|
||||
):
|
||||
if content_config.plaintext == None:
|
||||
logger.info("📄 No configuration for plaintext notes. Using default configuration.")
|
||||
default_configuration = default_config["content-type"]["plaintext"] # type: ignore
|
||||
content_config.plaintext = TextContentConfig(
|
||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
||||
embeddings_file=default_configuration["embeddings-file"],
|
||||
)
|
||||
|
||||
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
|
||||
logger.info("📄 Setting up search for plaintext")
|
||||
# Extract Entries, Generate Plaintext Embeddings
|
||||
content_index.plaintext = text_search.setup(
|
||||
text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
files.get("plaintext"),
|
||||
content_config.plaintext,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -331,7 +261,12 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Image Search
|
||||
if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
|
||||
if (
|
||||
(t == None or t == state.SearchType.Image.value)
|
||||
and content_config
|
||||
and content_config.image
|
||||
and search_models.image_search
|
||||
):
|
||||
logger.info("🌄 Setting up search for images")
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
content_index.image = image_search.setup(
|
||||
|
@ -342,17 +277,17 @@ def configure_content(
|
|||
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
|
||||
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
content_index.github = text_search.setup(
|
||||
text_search.setup(
|
||||
GithubToJsonl,
|
||||
None,
|
||||
content_config.github,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
config=github_config,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -360,42 +295,24 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Notion Search
|
||||
if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
|
||||
notion_config = NotionConfig.objects.filter(user=user).first()
|
||||
if (t == None or t in state.SearchType.Notion.value) and notion_config:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
content_index.notion = text_search.setup(
|
||||
text_search.setup(
|
||||
NotionToJsonl,
|
||||
None,
|
||||
content_config.notion,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
user=user,
|
||||
config=notion_config,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
|
||||
|
||||
try:
|
||||
# Initialize External Plugin Search
|
||||
if t == None and content_config.plugins and search_models.text_search:
|
||||
logger.info("🔌 Setting up search for plugins")
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
plugin_config,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
||||
full_corpus=full_corpus,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True)
|
||||
|
||||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
if user:
|
||||
state.query_cache[user.uuid] = LRU()
|
||||
|
||||
return content_index
|
||||
|
||||
|
@ -412,44 +329,9 @@ def load_content(
|
|||
if content_index is None:
|
||||
content_index = ContentIndex()
|
||||
|
||||
if content_config.org:
|
||||
logger.info("🦄 Loading orgmode notes")
|
||||
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
if content_config.markdown:
|
||||
logger.info("💎 Loading markdown notes")
|
||||
content_index.markdown = text_search.load(
|
||||
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.pdf:
|
||||
logger.info("🖨️ Loading pdf")
|
||||
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
if content_config.plaintext:
|
||||
logger.info("📄 Loading plaintext")
|
||||
content_index.plaintext = text_search.load(
|
||||
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.image:
|
||||
logger.info("🌄 Loading images")
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
)
|
||||
if content_config.github:
|
||||
logger.info("🐙 Loading github")
|
||||
content_index.github = text_search.load(
|
||||
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.notion:
|
||||
logger.info("🔌 Loading notion")
|
||||
content_index.notion = text_search.load(
|
||||
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
if content_config.plugins:
|
||||
logger.info("🔌 Loading plugins")
|
||||
content_index.plugins = {}
|
||||
for plugin_type, plugin_config in content_config.plugins.items():
|
||||
content_index.plugins[plugin_type] = text_search.load(
|
||||
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
|
||||
)
|
||||
|
||||
state.query_cache = LRU()
|
||||
return content_index
|
||||
|
|
|
@ -3,10 +3,20 @@ from fastapi import APIRouter
|
|||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
|
||||
from starlette.authentication import requires
|
||||
from khoj.utils.rawconfig import (
|
||||
TextContentConfig,
|
||||
OpenAIProcessorConfig,
|
||||
FullConfig,
|
||||
GithubContentConfig,
|
||||
GithubRepoConfig,
|
||||
NotionContentConfig,
|
||||
)
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
||||
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
import json
|
||||
|
||||
|
@ -29,10 +39,23 @@ def chat_page(request: Request):
|
|||
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
||||
|
||||
|
||||
def map_config_to_object(content_type: str):
|
||||
if content_type == "org":
|
||||
return LocalOrgConfig
|
||||
if content_type == "markdown":
|
||||
return LocalMarkdownConfig
|
||||
if content_type == "pdf":
|
||||
return LocalPdfConfig
|
||||
if content_type == "plaintext":
|
||||
return LocalPlaintextConfig
|
||||
|
||||
|
||||
if not state.demo:
|
||||
|
||||
@web_client.get("/config", response_class=HTMLResponse)
|
||||
def config_page(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||
default_full_config = FullConfig(
|
||||
content_type=None,
|
||||
search_type=None,
|
||||
|
@ -41,13 +64,13 @@ if not state.demo:
|
|||
current_config = state.config or json.loads(default_full_config.json())
|
||||
|
||||
successfully_configured = {
|
||||
"pdf": False,
|
||||
"markdown": False,
|
||||
"org": False,
|
||||
"pdf": ("pdf" in enabled_content),
|
||||
"markdown": ("markdown" in enabled_content),
|
||||
"org": ("org" in enabled_content),
|
||||
"image": False,
|
||||
"github": False,
|
||||
"notion": False,
|
||||
"plaintext": False,
|
||||
"github": ("github" in enabled_content),
|
||||
"notion": ("notion" in enabled_content),
|
||||
"plaintext": ("plaintext" in enabled_content),
|
||||
"enable_offline_model": False,
|
||||
"conversation_openai": False,
|
||||
"conversation_gpt4all": False,
|
||||
|
@ -56,13 +79,7 @@ if not state.demo:
|
|||
if state.content_index:
|
||||
successfully_configured.update(
|
||||
{
|
||||
"pdf": state.content_index.pdf is not None,
|
||||
"markdown": state.content_index.markdown is not None,
|
||||
"org": state.content_index.org is not None,
|
||||
"image": state.content_index.image is not None,
|
||||
"github": state.content_index.github is not None,
|
||||
"notion": state.content_index.notion is not None,
|
||||
"plaintext": state.content_index.plaintext is not None,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -84,22 +101,29 @@ if not state.demo:
|
|||
)
|
||||
|
||||
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
||||
@requires(["authenticated"])
|
||||
def github_config_page(request: Request):
|
||||
default_copy = constants.default_config.copy()
|
||||
default_github = default_copy["content-type"]["github"] # type: ignore
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
current_github_config = get_user_github_config(user)
|
||||
|
||||
default_config = TextContentConfig(
|
||||
compressed_jsonl=default_github["compressed-jsonl"],
|
||||
embeddings_file=default_github["embeddings-file"],
|
||||
)
|
||||
|
||||
current_config = (
|
||||
state.config.content_type.github
|
||||
if state.config and state.config.content_type and state.config.content_type.github
|
||||
else default_config
|
||||
)
|
||||
|
||||
current_config = json.loads(current_config.json())
|
||||
if current_github_config:
|
||||
raw_repos = current_github_config.githubrepoconfig.all()
|
||||
repos = []
|
||||
for repo in raw_repos:
|
||||
repos.append(
|
||||
GithubRepoConfig(
|
||||
name=repo.name,
|
||||
owner=repo.owner,
|
||||
branch=repo.branch,
|
||||
)
|
||||
)
|
||||
current_config = GithubContentConfig(
|
||||
pat_token=current_github_config.pat_token,
|
||||
repos=repos,
|
||||
)
|
||||
current_config = json.loads(current_config.json())
|
||||
else:
|
||||
current_config = {} # type: ignore
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"content_type_github_input.html", context={"request": request, "current_config": current_config}
|
||||
|
@ -107,18 +131,11 @@ if not state.demo:
|
|||
|
||||
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||
def notion_config_page(request: Request):
|
||||
default_copy = constants.default_config.copy()
|
||||
default_notion = default_copy["content-type"]["notion"] # type: ignore
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
current_notion_config = get_user_notion_config(user)
|
||||
|
||||
default_config = TextContentConfig(
|
||||
compressed_jsonl=default_notion["compressed-jsonl"],
|
||||
embeddings_file=default_notion["embeddings-file"],
|
||||
)
|
||||
|
||||
current_config = (
|
||||
state.config.content_type.notion
|
||||
if state.config and state.config.content_type and state.config.content_type.notion
|
||||
else default_config
|
||||
current_config = NotionContentConfig(
|
||||
token=current_notion_config.token if current_notion_config else "",
|
||||
)
|
||||
|
||||
current_config = json.loads(current_config.json())
|
||||
|
@ -132,18 +149,16 @@ if not state.demo:
|
|||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||
return templates.TemplateResponse("config.html", context={"request": request})
|
||||
|
||||
default_copy = constants.default_config.copy()
|
||||
default_content_type = default_copy["content-type"][content_type] # type: ignore
|
||||
object = map_config_to_object(content_type)
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
config = object.objects.filter(user=user).first()
|
||||
if config == None:
|
||||
config = object.objects.create(user=user)
|
||||
|
||||
default_config = TextContentConfig(
|
||||
compressed_jsonl=default_content_type["compressed-jsonl"],
|
||||
embeddings_file=default_content_type["embeddings-file"],
|
||||
)
|
||||
|
||||
current_config = (
|
||||
state.config.content_type[content_type]
|
||||
if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore
|
||||
else default_config
|
||||
current_config = TextContentConfig(
|
||||
input_files=config.input_files,
|
||||
input_filter=config.input_filter,
|
||||
index_heading_entries=config.index_heading_entries,
|
||||
)
|
||||
current_config = json.loads(current_config.json())
|
||||
|
||||
|
|
|
@ -1,16 +1,9 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from typing import List
|
||||
|
||||
|
||||
class BaseFilter(ABC):
|
||||
@abstractmethod
|
||||
def load(self, entries: List[Entry], *args, **kwargs):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_filter_terms(self, query: str) -> List[str]:
|
||||
...
|
||||
|
@ -18,10 +11,6 @@ class BaseFilter(ABC):
|
|||
def can_filter(self, raw_query: str) -> bool:
|
||||
return len(self.get_filter_terms(raw_query)) > 0
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def defilter(self, query: str) -> str:
|
||||
...
|
||||
|
|
|
@ -25,72 +25,42 @@ class DateFilter(BaseFilter):
|
|||
# - dt>="last week"
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
|
||||
raw_date_regex = r"\d{4}-\d{2}-\d{2}"
|
||||
|
||||
def __init__(self, entry_key="compiled"):
|
||||
self.entry_key = entry_key
|
||||
self.date_to_entry_ids = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
with timer("Created date filter index", logger):
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
except OSError:
|
||||
logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}")
|
||||
continue
|
||||
self.date_to_entry_ids[date_in_entry].add(id)
|
||||
def extract_dates(self, content):
|
||||
pattern_matched_dates = re.findall(self.raw_date_regex, content)
|
||||
|
||||
# Filter down to valid dates
|
||||
valid_dates = []
|
||||
for date_str in pattern_matched_dates:
|
||||
try:
|
||||
valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d"))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return valid_dates
|
||||
|
||||
def get_filter_terms(self, query: str) -> List[str]:
|
||||
"Get all filter terms in query"
|
||||
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
|
||||
|
||||
def get_query_date_range(self, query) -> List:
|
||||
with timer("Extract date range to filter from query", logger):
|
||||
query_daterange = self.extract_date_range(query)
|
||||
|
||||
return query_daterange
|
||||
|
||||
def defilter(self, query):
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||
return query
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
with timer("Extract date range to filter from query", logger):
|
||||
query_daterange = self.extract_date_range(query)
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange == []:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
query = self.defilter(query)
|
||||
|
||||
# return results from cache if exists
|
||||
cache_key = tuple(query_daterange)
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"Return date filter results from cache")
|
||||
entries_to_include = self.cache[cache_key]
|
||||
return query, entries_to_include
|
||||
|
||||
if not self.date_to_entry_ids:
|
||||
self.load(entries)
|
||||
|
||||
# find entries containing any dates that fall with date range specified in query
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
entries_to_include = set()
|
||||
for date_in_entry in self.date_to_entry_ids.keys():
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
||||
|
||||
# cache results
|
||||
self.cache[cache_key] = entries_to_include
|
||||
|
||||
return query, entries_to_include
|
||||
|
||||
def extract_date_range(self, query):
|
||||
# find date range filter in query
|
||||
date_range_matches = re.findall(self.date_regex, query)
|
||||
|
@ -138,6 +108,15 @@ class DateFilter(BaseFilter):
|
|||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||
return []
|
||||
else:
|
||||
# If the first element is 0, replace it with None
|
||||
|
||||
if effective_date_range[0] == 0:
|
||||
effective_date_range[0] = None
|
||||
|
||||
# If the second element is inf, replace it with None
|
||||
if effective_date_range[1] == inf:
|
||||
effective_date_range[1] = None
|
||||
|
||||
return effective_date_range
|
||||
|
||||
def parse(self, date_str, relative_base=None):
|
||||
|
|
|
@ -21,62 +21,13 @@ class FileFilter(BaseFilter):
|
|||
self.file_to_entry_map = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
with timer("Created file filter index", logger):
|
||||
for id, entry in enumerate(entries):
|
||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||
|
||||
def get_filter_terms(self, query: str) -> List[str]:
|
||||
"Get all filter terms in query"
|
||||
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
|
||||
return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)]
|
||||
|
||||
def convert_to_regex(self, file_filter: str) -> str:
|
||||
"Convert file filter to regex"
|
||||
return file_filter.replace(".", r"\.").replace("*", r".*")
|
||||
|
||||
def defilter(self, query: str) -> str:
|
||||
return re.sub(self.file_filter_regex, "", query).strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
# Extract file filters from raw query
|
||||
with timer("Extract files_to_search from query", logger):
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||
if not raw_files_to_search:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# Convert simple file filters with no path separator into regex
|
||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||
files_to_search = []
|
||||
for file in sorted(raw_files_to_search):
|
||||
if "/" not in file and "\\" not in file and "*" not in file:
|
||||
files_to_search += [f"*{file}"]
|
||||
else:
|
||||
files_to_search += [file]
|
||||
|
||||
# Remove filter terms from original query
|
||||
query = self.defilter(query)
|
||||
|
||||
# Return item from cache if exists
|
||||
cache_key = tuple(files_to_search)
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"Return file filter results from cache")
|
||||
included_entry_indices = self.cache[cache_key]
|
||||
return query, included_entry_indices
|
||||
|
||||
if not self.file_to_entry_map:
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
# Mark entries that contain any blocked_words for exclusion
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
included_entry_indices = set.union(
|
||||
*[
|
||||
self.file_to_entry_map[entry_file]
|
||||
for entry_file in self.file_to_entry_map.keys()
|
||||
for search_file in files_to_search
|
||||
if fnmatch.fnmatch(entry_file, search_file)
|
||||
],
|
||||
set(),
|
||||
)
|
||||
if not included_entry_indices:
|
||||
return query, {}
|
||||
|
||||
# Cache results
|
||||
self.cache[cache_key] = included_entry_indices
|
||||
|
||||
return query, included_entry_indices
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import List
|
|||
|
||||
# Internal Packages
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.utils.helpers import LRU, timer
|
||||
from khoj.utils.helpers import LRU
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -22,21 +22,6 @@ class WordFilter(BaseFilter):
|
|||
self.word_to_entry_index = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
with timer("Created word filter index", logger):
|
||||
self.cache = {} # Clear cache on filter (re-)load
|
||||
entry_splitter = (
|
||||
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
|
||||
)
|
||||
# Create map of words to entries they exist in
|
||||
for entry_index, entry in enumerate(entries):
|
||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
||||
if word == "":
|
||||
continue
|
||||
self.word_to_entry_index[word].add(entry_index)
|
||||
|
||||
return self.word_to_entry_index
|
||||
|
||||
def get_filter_terms(self, query: str) -> List[str]:
|
||||
"Get all filter terms in query"
|
||||
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
|
||||
|
@ -45,47 +30,3 @@ class WordFilter(BaseFilter):
|
|||
|
||||
def defilter(self, query: str) -> str:
|
||||
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
with timer("Extract required, blocked filters from query", logger):
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = self.defilter(query)
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# Return item from cache if exists
|
||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"Return word filter results from cache")
|
||||
included_entry_indices = self.cache[cache_key]
|
||||
return query, included_entry_indices
|
||||
|
||||
if not self.word_to_entry_index:
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
# mark entries that contain all required_words for inclusion
|
||||
with timer("Mark entries satisfying filter", logger):
|
||||
entries_with_all_required_words = set(range(len(entries)))
|
||||
if len(required_words) > 0:
|
||||
entries_with_all_required_words = set.intersection(
|
||||
*[self.word_to_entry_index.get(word, set()) for word in required_words]
|
||||
)
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
entries_with_any_blocked_words = set()
|
||||
if len(blocked_words) > 0:
|
||||
entries_with_any_blocked_words = set.union(
|
||||
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
|
||||
)
|
||||
|
||||
# get entries satisfying inclusion and exclusion filters
|
||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
||||
|
||||
# Cache results
|
||||
self.cache[cache_key] = included_entry_indices
|
||||
|
||||
return query, included_entry_indices
|
||||
|
|
|
@ -2,25 +2,39 @@
|
|||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Type, Union
|
||||
from typing import List, Tuple, Type, Union, Dict
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import TextContent, TextSearchModel
|
||||
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.models import KhojUser, Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
search_type_to_embeddings_type = {
|
||||
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
|
||||
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
|
||||
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
|
||||
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
|
||||
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
|
||||
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
|
||||
SearchType.All.value: None,
|
||||
}
|
||||
|
||||
|
||||
def initialize_model(search_config: TextSearchConfig):
|
||||
"Initialize model for semantic search on text"
|
||||
|
@ -117,171 +131,102 @@ def load_embeddings(
|
|||
|
||||
|
||||
async def query(
|
||||
user: KhojUser,
|
||||
raw_query: str,
|
||||
search_model: TextSearchModel,
|
||||
content: TextContent,
|
||||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
dedupe: bool = True,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
if (
|
||||
content.entries is None
|
||||
or len(content.entries) == 0
|
||||
or content.corpus_embeddings is None
|
||||
or len(content.corpus_embeddings) == 0
|
||||
):
|
||||
return [], []
|
||||
|
||||
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
|
||||
file_type = search_type_to_embeddings_type[type.value]
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
|
||||
|
||||
# If no entries left after filtering, return empty results
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
# If query only had filters it'll be empty now. So short-circuit and return results.
|
||||
if query.strip() == "":
|
||||
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
||||
return hits, entries
|
||||
query = raw_query
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
question_embedding = state.embeddings_model.embed_query(query)
|
||||
|
||||
# Find relevant entries for the query
|
||||
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
|
||||
top_k = 10
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
|
||||
hits = EmbeddingsAdapters.search_with_embeddings(
|
||||
user=user,
|
||||
embeddings=question_embedding,
|
||||
max_results=top_k,
|
||||
file_type_filter=file_type,
|
||||
raw_query=raw_query,
|
||||
).all()
|
||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||
|
||||
return hits
|
||||
|
||||
|
||||
def collate_results(hits, dedupe=True):
|
||||
hit_ids = set()
|
||||
for hit in hits:
|
||||
if dedupe and hit.corpus_id in hit_ids:
|
||||
continue
|
||||
|
||||
else:
|
||||
hit_ids.add(hit.corpus_id)
|
||||
yield SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": hit.raw,
|
||||
"score": hit.distance,
|
||||
"additional": {
|
||||
"file": hit.file_path,
|
||||
"compiled": hit.compiled,
|
||||
"heading": hit.heading,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query):
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results and search_model.cross_encoder:
|
||||
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
|
||||
hits = cross_encoder_score(query, hits)
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
||||
# Sort results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results=True, hits=hits)
|
||||
|
||||
# Order results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results, hits)
|
||||
|
||||
# Deduplicate entries by raw entry text before showing to users
|
||||
if dedupe:
|
||||
hits = deduplicate_results(entries, hits)
|
||||
|
||||
return hits, entries
|
||||
|
||||
|
||||
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
|
||||
return [
|
||||
SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": entries[hit["corpus_id"]].raw,
|
||||
"score": f"{hit.get('cross-score') or hit.get('score')}",
|
||||
"additional": {
|
||||
"file": entries[hit["corpus_id"]].file,
|
||||
"compiled": entries[hit["corpus_id"]].compiled,
|
||||
"heading": entries[hit["corpus_id"]].heading,
|
||||
},
|
||||
}
|
||||
)
|
||||
for hit in hits[0:count]
|
||||
]
|
||||
return hits
|
||||
|
||||
|
||||
def setup(
|
||||
text_to_jsonl: Type[TextToJsonl],
|
||||
text_to_jsonl: Type[TextEmbeddings],
|
||||
files: dict[str, str],
|
||||
config: TextConfigBase,
|
||||
bi_encoder: BaseEncoder,
|
||||
regenerate: bool,
|
||||
filters: List[BaseFilter] = [],
|
||||
normalize: bool = True,
|
||||
full_corpus: bool = True,
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
previous_entries = []
|
||||
if config.compressed_jsonl.exists() and not regenerate:
|
||||
previous_entries = extract_entries(config.compressed_jsonl)
|
||||
entries_with_indices = text_to_jsonl(config).process(
|
||||
previous_entries=previous_entries, files=files, full_corpus=full_corpus
|
||||
)
|
||||
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
if is_none_or_empty(entries):
|
||||
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
||||
raise ValueError(
|
||||
f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}"
|
||||
user: KhojUser = None,
|
||||
config=None,
|
||||
) -> None:
|
||||
if config:
|
||||
num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process(
|
||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
)
|
||||
else:
|
||||
num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process(
|
||||
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||
)
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(
|
||||
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
|
||||
file_names = [file_name for file_name in files]
|
||||
|
||||
logger.info(
|
||||
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
|
||||
)
|
||||
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def load(
|
||||
config: TextConfigBase,
|
||||
filters: List[BaseFilter] = [],
|
||||
) -> TextContent:
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = load_embeddings(config.embeddings_file)
|
||||
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=False)
|
||||
|
||||
return TextContent(entries, corpus_embeddings, filters)
|
||||
|
||||
|
||||
def apply_filters(
|
||||
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
|
||||
) -> Tuple[str, List[Entry], torch.Tensor]:
|
||||
"""Filter query, entries and embeddings before semantic search"""
|
||||
|
||||
with timer("Total Filter Time", logger, state.device):
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return "", [], torch.tensor([], device=state.device)
|
||||
else:
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(
|
||||
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
|
||||
)
|
||||
|
||||
return query, entries, corpus_embeddings
|
||||
|
||||
|
||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
||||
"""Score all retrieved entries using the cross-encoder"""
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]["cross-score"] = cross_scores[idx]
|
||||
hits[idx]["cross_score"] = cross_scores[idx]
|
||||
|
||||
return hits
|
||||
|
||||
|
@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
|||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
|
||||
return hits
|
||||
|
||||
|
||||
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||
"""Deduplicate entries by raw entry text before showing to users
|
||||
Compiled entries are split by max tokens supported by ML models.
|
||||
This can result in duplicate hits, entries shown to user."""
|
||||
|
||||
with timer("Deduplication Time", logger, state.device):
|
||||
seen, original_hits_count = set(), len(hits)
|
||||
hits = [
|
||||
hit
|
||||
for hit in hits
|
||||
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
|
||||
]
|
||||
duplicate_hits = original_hits_count - len(hits)
|
||||
|
||||
logger.debug(f"Removed {duplicate_hits} duplicates")
|
||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
||||
return hits
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import argparse
|
||||
import pathlib
|
||||
from importlib.metadata import version
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
|
@ -34,6 +35,12 @@ def cli(args=None):
|
|||
)
|
||||
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
||||
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
|
||||
parser.add_argument(
|
||||
"--anonymous-mode",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
|
||||
)
|
||||
|
||||
args = parser.parse_args(args)
|
||||
|
||||
|
@ -51,6 +58,8 @@ def cli(args=None):
|
|||
else:
|
||||
args = run_migrations(args)
|
||||
args.config = parse_config_from_file(args.config_file)
|
||||
if os.environ.get("DEBUG"):
|
||||
args.config.app.should_log_telemetry = False
|
||||
|
||||
return args
|
||||
|
||||
|
|
|
@ -41,9 +41,7 @@ class ProcessorType(str, Enum):
|
|||
|
||||
@dataclass
|
||||
class TextContent:
|
||||
entries: List[Entry]
|
||||
corpus_embeddings: torch.Tensor
|
||||
filters: List[BaseFilter]
|
||||
enabled: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -67,21 +65,13 @@ class ImageSearchModel:
|
|||
|
||||
@dataclass
|
||||
class ContentIndex:
|
||||
org: Optional[TextContent] = None
|
||||
markdown: Optional[TextContent] = None
|
||||
pdf: Optional[TextContent] = None
|
||||
github: Optional[TextContent] = None
|
||||
notion: Optional[TextContent] = None
|
||||
image: Optional[ImageContent] = None
|
||||
plaintext: Optional[TextContent] = None
|
||||
plugins: Optional[Dict[str, TextContent]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchModels:
|
||||
text_search: Optional[TextSearchModel] = None
|
||||
image_search: Optional[ImageSearchModel] = None
|
||||
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/"
|
|||
empty_escape_sequences = "\n|\r|\t| "
|
||||
app_env_filepath = "~/.khoj/env"
|
||||
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
||||
content_directory = "~/.khoj/content/"
|
||||
|
||||
empty_config = {
|
||||
"content-type": {
|
||||
|
|
|
@ -5,29 +5,39 @@ from typing import Optional
|
|||
from bs4 import BeautifulSoup
|
||||
|
||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from khoj.utils.rawconfig import TextContentConfig, ContentConfig
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
from khoj.utils.config import SearchType
|
||||
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
|
||||
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
|
||||
files = {}
|
||||
|
||||
if config is None:
|
||||
return files
|
||||
|
||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||
files["org"] = get_org_files(config.org) if config.org else {}
|
||||
org_config = LocalOrgConfig.objects.filter(user=user).first()
|
||||
files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
||||
files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
|
||||
markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
|
||||
files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Plaintext:
|
||||
files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
|
||||
plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
|
||||
files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
|
||||
if search_type == SearchType.All or search_type == SearchType.Pdf:
|
||||
files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
|
||||
pdf_config = LocalPdfConfig.objects.filter(user=user).first()
|
||||
files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
|
||||
return files
|
||||
|
||||
|
||||
def construct_config_from_db(db_config) -> TextContentConfig:
|
||||
return TextContentConfig(
|
||||
input_files=db_config.input_files,
|
||||
input_filter=db_config.input_filter,
|
||||
index_heading_entries=db_config.index_heading_entries,
|
||||
)
|
||||
|
||||
|
||||
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
|
||||
def is_plaintextfile(file: str):
|
||||
"Check if file is plaintext file"
|
||||
|
|
|
@ -209,10 +209,12 @@ def log_telemetry(
|
|||
if not app_config or not app_config.should_log_telemetry:
|
||||
return []
|
||||
|
||||
if properties.get("server_id") is None:
|
||||
properties["server_id"] = get_server_id()
|
||||
|
||||
# Populate telemetry data to log
|
||||
request_body = {
|
||||
"telemetry_type": telemetry_type,
|
||||
"server_id": get_server_id(),
|
||||
"server_version": version("khoj-assistant"),
|
||||
"os": platform.system(),
|
||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
# System Packages
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
from typing import List, Dict, Optional
|
||||
import uuid
|
||||
|
||||
# External Packages
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
|
||||
from khoj.utils.helpers import to_snake_case_from_dash
|
||||
|
||||
|
||||
class ConfigBase(BaseModel):
|
||||
|
@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase):
|
|||
embeddings_file: Path
|
||||
|
||||
|
||||
class TextContentConfig(TextConfigBase):
|
||||
class TextContentConfig(ConfigBase):
|
||||
input_files: Optional[List[Path]]
|
||||
input_filter: Optional[List[str]]
|
||||
index_heading_entries: Optional[bool] = False
|
||||
|
@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase):
|
|||
branch: Optional[str] = "master"
|
||||
|
||||
|
||||
class GithubContentConfig(TextConfigBase):
|
||||
class GithubContentConfig(ConfigBase):
|
||||
pat_token: str
|
||||
repos: List[GithubRepoConfig]
|
||||
|
||||
|
||||
class NotionContentConfig(TextConfigBase):
|
||||
class NotionContentConfig(ConfigBase):
|
||||
token: str
|
||||
|
||||
|
||||
|
@ -63,7 +64,6 @@ class ContentConfig(ConfigBase):
|
|||
pdf: Optional[TextContentConfig]
|
||||
plaintext: Optional[TextContentConfig]
|
||||
github: Optional[GithubContentConfig]
|
||||
plugins: Optional[Dict[str, TextContentConfig]]
|
||||
notion: Optional[NotionContentConfig]
|
||||
|
||||
|
||||
|
@ -122,7 +122,8 @@ class FullConfig(ConfigBase):
|
|||
|
||||
class SearchResponse(ConfigBase):
|
||||
entry: str
|
||||
score: str
|
||||
score: float
|
||||
cross_score: Optional[float]
|
||||
additional: Optional[dict]
|
||||
|
||||
|
||||
|
@ -131,14 +132,21 @@ class Entry:
|
|||
compiled: str
|
||||
heading: Optional[str]
|
||||
file: Optional[str]
|
||||
corpus_id: str
|
||||
|
||||
def __init__(
|
||||
self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None
|
||||
self,
|
||||
raw: str = None,
|
||||
compiled: str = None,
|
||||
heading: Optional[str] = None,
|
||||
file: Optional[str] = None,
|
||||
corpus_id: uuid.UUID = None,
|
||||
):
|
||||
self.raw = raw
|
||||
self.compiled = compiled
|
||||
self.heading = heading
|
||||
self.file = file
|
||||
self.corpus_id = str(corpus_id)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.__dict__, ensure_ascii=False)
|
||||
|
@ -153,4 +161,5 @@ class Entry:
|
|||
compiled=dictionary["compiled"],
|
||||
file=dictionary.get("file", None),
|
||||
heading=dictionary.get("heading", None),
|
||||
corpus_id=dictionary.get("corpus_id", None),
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import threading
|
||||
from typing import List, Dict
|
||||
from packaging import version
|
||||
from collections import defaultdict
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -12,10 +13,13 @@ from khoj.utils import config as utils_config
|
|||
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.helpers import LRU
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
search_models = SearchModels()
|
||||
embeddings_model = EmbeddingsModel()
|
||||
cross_encoder_model = CrossEncoderModel()
|
||||
content_index = ContentIndex()
|
||||
processor_config = ProcessorConfigModel()
|
||||
config_file: Path = None
|
||||
|
@ -23,14 +27,14 @@ verbose: int = 0
|
|||
host: str = None
|
||||
port: int = None
|
||||
cli_args: List[str] = None
|
||||
query_cache = LRU()
|
||||
query_cache: Dict[str, LRU] = defaultdict(LRU)
|
||||
config_lock = threading.Lock()
|
||||
chat_lock = threading.Lock()
|
||||
SearchType = utils_config.SearchType
|
||||
telemetry: List[Dict[str, str]] = []
|
||||
previous_query: str = None
|
||||
demo: bool = False
|
||||
khoj_version: str = None
|
||||
anonymous_mode: bool = False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
# External Packages
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI
|
||||
import factory
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from app.main import app
|
||||
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
|
@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
|
|||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
GithubRepoConfig,
|
||||
ImageContentConfig,
|
||||
SearchConfig,
|
||||
TextSearchConfig,
|
||||
|
@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
|
|||
)
|
||||
from khoj.utils import state, fs_syncer
|
||||
from khoj.routers.indexer import configure_content
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from database.models import (
|
||||
LocalOrgConfig,
|
||||
LocalMarkdownConfig,
|
||||
LocalPlaintextConfig,
|
||||
LocalPdfConfig,
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
GithubRepoConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_db_access_for_all_tests(db):
|
||||
pass
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
|
||||
username = factory.Faker("name")
|
||||
email = factory.Faker("email")
|
||||
password = factory.Faker("password")
|
||||
uuid = factory.Faker("uuid4")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
|
|||
return search_config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user():
|
||||
return UserFactory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
||||
@pytest.fixture
|
||||
def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture(scope="function")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
|
@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
|||
|
||||
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
LocalOrgConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
content_config.plugins = {
|
||||
"plugin1": TextContentConfig(
|
||||
input_files=[content_dir.joinpath("notes.jsonl.gz")],
|
||||
input_filter=None,
|
||||
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
|
||||
)
|
||||
}
|
||||
text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
|
||||
|
||||
if os.getenv("GITHUB_PAT_TOKEN"):
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
GithubConfig.objects.create(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
content_config.plaintext = TextContentConfig(
|
||||
GithubRepoConfig.objects.create(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
github_config=GithubConfig.objects.get(user=default_user),
|
||||
)
|
||||
|
||||
LocalPlaintextConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
||||
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
|
||||
)
|
||||
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
content_config.plugins["plugin1"],
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
return content_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def md_content_config(tmp_path_factory):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Embeddings for Markdown Content
|
||||
content_config = ContentConfig()
|
||||
content_config.markdown = TextContentConfig(
|
||||
def md_content_config():
|
||||
markdown_config = LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
|
||||
)
|
||||
|
||||
return content_config
|
||||
return markdown_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
|
|||
@pytest.fixture(scope="session")
|
||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
# Initialize app state
|
||||
state.config.content_type = md_content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
all_files = fs_syncer.collect_files()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
|
@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
|||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
def fastapi_app():
|
||||
app = FastAPI()
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(
|
||||
content_config: ContentConfig,
|
||||
search_config: SearchConfig,
|
||||
processor_config: ProcessorConfig,
|
||||
default_user: KhojUser,
|
||||
):
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
state.content_index.plaintext = text_search.setup(
|
||||
text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
get_sample_data("plaintext"),
|
||||
content_config.plaintext,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
|
@ -288,7 +285,6 @@ def client_offline_chat(
|
|||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
|
@ -298,6 +294,7 @@ def client_offline_chat(
|
|||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
|
@ -306,9 +303,11 @@ def client_offline_chat(
|
|||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def new_org_file(content_config: ContentConfig):
|
||||
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
|
||||
# Setup
|
||||
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
input_filters = org_config.input_filter
|
||||
new_org_file = Path(input_filters[0]).parent / "new_file.org"
|
||||
new_org_file.touch()
|
||||
|
||||
yield new_org_file
|
||||
|
@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
|
|||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
||||
new_org_config = deepcopy(content_config.org)
|
||||
new_org_config.input_files = [f"{new_org_file}"]
|
||||
new_org_config.input_filter = None
|
||||
return new_org_config
|
||||
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
|
||||
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
|
||||
return LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
|
11
tests/data/config.yml
vendored
11
tests/data/config.yml
vendored
|
@ -9,17 +9,6 @@ content-type:
|
|||
input-filter:
|
||||
- '*.org'
|
||||
- ~/notes/*.org
|
||||
plugins:
|
||||
content_plugin_1:
|
||||
compressed-jsonl: content_plugin_1.jsonl.gz
|
||||
embeddings-file: content_plugin_1_embeddings.pt
|
||||
input-files:
|
||||
- content_plugin_1_new.jsonl.gz
|
||||
content_plugin_2:
|
||||
compressed-jsonl: content_plugin_2.jsonl.gz
|
||||
embeddings-file: content_plugin_2_embeddings.pt
|
||||
input-filter:
|
||||
- '*2_new.jsonl.gz'
|
||||
enable-offline-chat: false
|
||||
search-type:
|
||||
asymmetric:
|
||||
|
|
|
@ -48,14 +48,3 @@ def test_cli_config_from_file():
|
|||
Path("~/first_from_config.org"),
|
||||
Path("~/second_from_config.org"),
|
||||
]
|
||||
assert len(actual_args.config.content_type.plugins.keys()) == 2
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
|
||||
Path("content_plugin_1_new.jsonl.gz")
|
||||
]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
|
||||
"content_plugin_1.jsonl.gz"
|
||||
)
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
|
||||
"content_plugin_2_embeddings.pt"
|
||||
)
|
||||
|
|
|
@ -2,22 +2,21 @@
|
|||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
|
||||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Internal Packages
|
||||
from app.main import app
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from database.models import KhojUser
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
|
||||
|
||||
# Test
|
||||
|
@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q=random&t={content_type}")
|
||||
# Assert
|
||||
|
@ -75,7 +74,7 @@ def test_index_update(client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_regenerate_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
|
@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
||||
def test_get_configured_types_via_api(client):
|
||||
def test_get_configured_types_via_api(client, sample_org_data):
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
|
||||
assert list(enabled_types) == ["org"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_only_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
config.content_type.plugins = content_config.plugins
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "plugin1"]
|
||||
assert response.json() == ["all", "org", "markdown", "image"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||
# Arrange
|
||||
config.content_type = content_config
|
||||
config.content_type.plugins = None
|
||||
state.SearchType = configure_search_types(config)
|
||||
original_config = state.config.content_type
|
||||
state.config.content_type = None
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "plugin1" not in response.json()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_content_config():
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
configure_routes(fastapi_app)
|
||||
client = TestClient(fastapi_app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
|
|||
assert response.status_code == 200
|
||||
assert response.json() == ["all"]
|
||||
|
||||
# Restore
|
||||
state.config.content_type = original_config
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
|
@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
||||
# Arrange
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
|
@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_only_filters(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
|
@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_include_filter(client, sample_org_data):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
# Act
|
||||
|
@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_exclude_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_exclude_filter(client, sample_org_data):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
|
@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
|
|||
assert "clone" not in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
    """Entries indexed for `default_user` must not be returned to a different requester.

    The org sample data is indexed under `default_user`, then the search API is
    queried through the `client` fixture and the response is expected to be empty.
    """
    # Arrange: index the sample org data under default_user only
    text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
    user_query = quote("How to git install application?")
    search_url = f"/api/search?q={user_query}&n=1&t=org"

    # Act
    response = client.get(search_url)

    # Assert
    assert response.status_code == 200
    # The response carries no results because default_user is not the user
    # making the query (anonymous)
    results = response.json()
    assert len(results) == 0
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return {
|
||||
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
|
||||
|
|
|
@ -1,53 +1,12 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
from datetime import datetime
|
||||
from math import inf
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
def test_date_filter():
    """Check `DateFilter.apply` against open and closed date-range queries.

    Each case applies a fresh `DateFilter` to the same three entries and checks
    that the date terms are stripped from the query (leaving "head tail") and
    that the expected set of entry indices is returned.
    """
    # Entry 0 has no date, entry 1 is dated 1984-04-01, entry 2 is dated 1984-04-02
    entry_texts = [
        "Entry with no date",
        "April Fools entry: 1984-04-01",
        "Entry with date:1984-04-02",
    ]
    entries = [Entry(compiled=text, raw=text) for text in entry_texts]

    # (raw query, expected entry indices), evaluated in this order
    cases = [
        # no date filter: every entry is kept
        ("head tail", {0, 1, 2}),
        # exclusive bounds on adjacent days: nothing falls inside the range
        ('head dt>"1984-04-01" dt<"1984-04-02" tail', set()),
        # exclusive bounds wide enough to contain 1984-04-02
        ('head dt>"1984-04-01" dt<"1984-04-03" tail', {2}),
        # inclusive lower bound, exclusive upper bound
        ('head dt>="1984-04-01" dt<"1984-04-02" tail', {1}),
        # exclusive lower bound, inclusive upper bound
        ('head dt>"1984-04-01" dt<="1984-04-02" tail', {2}),
        # both bounds inclusive
        ('head dt>="1984-04-01" dt<="1984-04-02" tail', {1, 2}),
    ]

    for raw_query, expected_indices in cases:
        ret_query, entry_indices = DateFilter().apply(raw_query, entries)
        assert ret_query == "head tail"
        assert entry_indices == expected_indices
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||
|
@ -56,8 +15,8 @@ def test_extract_date_range():
|
|||
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
||||
]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
|
||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
||||
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
||||
|
|
|
@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
|
|||
def test_no_file_filter():
    """A query with no `file:` term is not filterable and is returned unchanged."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = "head tail"

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == False
    assert remaining_query == "head tail"
    # expects all four indices back
    assert matched_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_file_filter_with_non_existent_file():
    """A `file:` term naming an unknown file is stripped from the query and matches nothing."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = 'head file:"nonexistent.org" tail'

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == True
    assert remaining_query == "head tail"
    # NOTE: `{}` is an empty dict literal, not an empty set; an empty set would
    # not compare equal to it
    assert matched_indices == {}
|
||||
|
||||
|
||||
def test_single_file_filter():
    """A single `file:` term with a full file name selects only the expected indices."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = 'head file:"file 1.org" tail'

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == True
    assert remaining_query == "head tail"
    assert matched_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_partial_match():
    """A `file:` term holding only the tail of a file name ("1.org") yields the same indices as the full name."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = 'head file:"1.org" tail'

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == True
    assert remaining_query == "head tail"
    assert matched_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_regex_match():
    """A wildcard `file:"*.org"` term is expected to select all four indices."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = 'head file:"*.org" tail'

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == True
    assert remaining_query == "head tail"
    assert matched_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_multiple_file_filter():
    """Two `file:` terms in one query select the union of their matches (all four indices here)."""
    # Arrange
    file_filter = FileFilter()
    entries = arrange_content()
    query = 'head tail file:"file 1.org" file:"file2.org"'

    # Act
    filterable = file_filter.can_filter(query)
    remaining_query, matched_indices = file_filter.apply(query, entries)

    # Assert
    assert filterable == True
    assert remaining_query == "head tail"
    assert matched_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_get_file_filter_terms():
|
||||
|
@ -108,7 +84,7 @@ def test_get_file_filter_terms():
|
|||
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||
|
||||
# Assert
|
||||
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
||||
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
|
||||
|
||||
|
||||
def arrange_content():
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
# Internal Packages
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
def test_process_entries_from_single_input_jsonl(tmp_path):
    """Round-trip two JSONL entries held in a single file.

    The file is read with `JsonlToJsonl.extract_jsonl_entries`, each record is
    turned into an `Entry`, and the entries are serialized back to JSONL; the
    output text must equal the input text.
    """
    # Arrange: one file holding two newline-terminated JSON records
    input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
"""
    source_file = create_file(tmp_path, input_jsonl)

    # Act: extract records, build entries, serialize them again
    records = JsonlToJsonl.extract_jsonl_entries([source_file])
    entries = [Entry.from_dict(record) for record in records]
    output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)

    # Assert
    assert len(entries) == 2
    assert output_jsonl == input_jsonl
|
||||
|
||||
|
||||
def test_process_entries_from_multiple_input_jsonls(tmp_path):
    """Round-trip JSONL entries spread across two input files.

    Each file holds one record without a trailing newline; the serialized
    output is expected to be both records, each followed by a newline.
    """
    # Arrange: two files, one JSON record each
    input_jsonl_1 = '{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}'
    input_jsonl_2 = '{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}'
    first_file = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
    second_file = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")

    # Act: extract records from both files, build entries, serialize them again
    records = JsonlToJsonl.extract_jsonl_entries([first_file, second_file])
    entries = [Entry.from_dict(record) for record in records]
    output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)

    # Assert
    assert len(entries) == 2
    assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
|
||||
|
||||
|
||||
def test_get_jsonl_files(tmp_path):
    """Ensure JSONL files named via input-files or matched by input-filter globs are returned.

    Five files should be picked up (four through the two globs, one listed
    explicitly); two further files in the same directory match neither and
    must be left out.
    """
    # Arrange
    # Files expected in the result: four matched by globs, then the explicitly listed one
    included_names = [
        "group1-file1.jsonl",
        "group1-file2.jsonl",
        "group2-file1.jsonl",
        "group2-file2.jsonl",
        "notes.jsonl",
    ]
    included_paths = [create_file(tmp_path, filename=name) for name in included_names]

    # Files that no input-file or input-filter refers to
    for excluded_name in ("not-included-jsonl.jsonl", "not-included-text.txt"):
        create_file(tmp_path, filename=excluded_name)

    expected_files = sorted(str(path) for path in included_paths)

    # Setup input-files, input-filters
    input_files = [tmp_path / "notes.jsonl"]
    input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]

    # Act
    extracted_jsonl_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)

    # Assert
    assert len(extracted_jsonl_files) == 5
    assert extracted_jsonl_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry=None, filename="test.jsonl"):
    """Create `filename` inside `tmp_path` and return its path.

    The file is always created via `touch()`. Text is written to it only when
    `entry` is truthy, so `None` or an empty string leaves the file as
    `touch()` made it.
    """
    target = tmp_path / filename
    target.touch()
    if not entry:
        return target
    target.write_text(entry)
    return target
|
|
@ -4,7 +4,7 @@ import os
|
|||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
|
@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
|||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
TextToJsonl.split_entries_by_max_tokens(
|
||||
TextEmbeddings.split_entries_by_max_tokens(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
)
|
||||
)
|
||||
|
@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
|
|||
|
||||
# Act
|
||||
# Split entry by max words and drop words larger than max word length
|
||||
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
|
||||
# Assert
|
||||
# "Heading" dropped from compiled version because its over the set max word limit
|
||||
|
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
|||
from khoj.utils.fs_syncer import get_plaintext_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from database.models import LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
def test_plaintext_file(tmp_path):
|
||||
|
@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
|
|||
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
||||
|
||||
|
||||
def test_parse_html_plaintext_file(content_config):
|
||||
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
||||
"Ensure HTML files are parsed correctly"
|
||||
# Arrange
|
||||
# Setup input-files, input-filters
|
||||
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
|
||||
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
|
||||
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||
|
||||
# Act
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
|
|
|
@ -3,23 +3,30 @@ import logging
|
|||
import locale
|
||||
from pathlib import Path
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import content_index, search_models
|
||||
from khoj.search_type import text_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
from khoj.utils.fs_syncer import get_org_files, collect_files
|
||||
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_missing_file_raises_error(
|
||||
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
|
||||
):
|
||||
# Arrange
|
||||
# Ensure file mentioned in org.input-files is missing
|
||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
|
||||
# Arrange
|
||||
orgfile = tmp_path / "directory.org" / "file.org"
|
||||
orgfile.parent.mkdir()
|
||||
with open(orgfile, "w") as f:
|
||||
f.write("* Heading\n- List item\n")
|
||||
org_content_config = TextContentConfig(
|
||||
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
|
||||
|
||||
LocalOrgConfig.objects.create(
|
||||
input_filter=[f"{tmp_path}/**/*"],
|
||||
input_files=None,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
org_files = collect_files(user=default_user)["org"]
|
||||
|
||||
# Act
|
||||
# should not raise IsADirectoryError and return orgfile
|
||||
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
||||
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(notes_model.entries) == 10
|
||||
assert len(notes_model.corpus_embeddings) == 10
|
||||
assert "Deleting all embeddings for file type org" in caplog.records[1].message
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
@pytest.mark.django_db
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert "Creating index from scratch." in initial_logs
|
||||
assert "Creating index from scratch." not in final_logs
|
||||
assert "Deleting all embeddings for file type org" in initial_logs
|
||||
assert "Deleting all embeddings for file type org" not in final_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.anyio
|
||||
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# @pytest.mark.asyncio
|
||||
async def test_text_search(search_config: SearchConfig):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
default_user = await KhojUser.objects.acreate(
|
||||
username="test_user", password="test_password", email="test@example.com"
|
||||
)
|
||||
# Arrange
|
||||
org_config = await LocalOrgConfig.objects.acreate(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
text_search.setup,
|
||||
OrgToJsonl,
|
||||
data,
|
||||
True,
|
||||
True,
|
||||
default_user,
|
||||
)
|
||||
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
hits, entries = await text_search.query(
|
||||
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
||||
)
|
||||
|
||||
results = text_search.collate_results(hits, entries, count=1)
|
||||
hits = await text_search.query(default_user, query)
|
||||
|
||||
# Assert
|
||||
results = text_search.collate_results(hits)
|
||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||
# search results should contain "git clone" entry
|
||||
search_result = results[0].entry
|
||||
assert "git clone" in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
max_tokens = 256
|
||||
|
@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
|||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 2
|
||||
assert len(initial_notes_model.corpus_embeddings) == 2
|
||||
record = caplog.records[1]
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
data = {
|
||||
"readme.org": """
|
||||
* Khoj
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
|
||||
** Dependencies
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
|
||||
** Install
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
}
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
max_tokens = 256
|
||||
|
@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
|
|||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
record = caplog.records[1]
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 5
|
||||
assert len(initial_notes_model.corpus_embeddings) == 5
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
||||
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
org_config.input_files = [f"{new_org_file}"]
|
||||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(regenerated_notes_model.entries) == 11
|
||||
assert len(regenerated_notes_model.corpus_embeddings) == 11
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, regenerated_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
|||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) == 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
|
||||
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(initial_index, updated_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
|
@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
|||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
|
@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
|||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) + 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
|
||||
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(updated_index, initial_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
f.write(new_entry)
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
final_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
|
||||
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, final_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||
# Arrange
|
||||
github_config = GithubConfig.objects.filter(user=default_user).first()
|
||||
# Act
|
||||
# Regenerate github embeddings to test asymmetric setup without caching
|
||||
github_model = text_search.setup(
|
||||
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
||||
text_search.setup(
|
||||
GithubToJsonl,
|
||||
{},
|
||||
regenerate=True,
|
||||
user=default_user,
|
||||
config=github_config,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(github_model.entries) > 1
|
||||
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
|
||||
assert embeddings > 1
|
||||
|
||||
|
||||
def compare_index(initial_notes_model, final_notes_model):
|
||||
mismatched_entries, mismatched_embeddings = [], []
|
||||
for index in range(len(initial_notes_model.entries)):
|
||||
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
|
||||
mismatched_entries.append(index)
|
||||
|
||||
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
|
||||
for index in range(len(initial_notes_model.corpus_embeddings)):
|
||||
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
|
||||
mismatched_embeddings.append(index)
|
||||
|
||||
error_details = ""
|
||||
if mismatched_entries:
|
||||
mismatched_entries_str = ",".join(map(str, mismatched_entries))
|
||||
error_details += f"Entries at {mismatched_entries_str} not equal\n"
|
||||
if mismatched_embeddings:
|
||||
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
|
||||
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
|
||||
|
||||
return error_details
|
||||
def verify_embeddings(expected_count, user):
|
||||
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
|
||||
assert embeddings == expected_count
|
||||
|
|
|
@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
|
|||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
def test_no_word_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = "head tail"
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_word_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_word_include_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_filter = 'head +"include_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2, 3}
|
||||
|
||||
|
||||
def test_word_include_and_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
|
||||
def test_get_word_filter_terms():
|
||||
|
|
Loading…
Reference in a new issue