mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)
- Partition configuration for indexing local data based on user accounts - Store indexed data in an underlying postgres db using the `pgvector` extension - Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time - Apply filters using SQL queries - Start removing many server-level configuration settings - Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB. - Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
parent
963cd165eb
commit
216acf545f
60 changed files with 1827 additions and 1792 deletions
48
.github/workflows/pre-commit.yml
vendored
Normal file
48
.github/workflows/pre-commit.yml
vendored
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
name: pre-commit
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- src/**
|
||||||
|
- tests/**
|
||||||
|
- config/**
|
||||||
|
- pyproject.toml
|
||||||
|
- .pre-commit-config.yml
|
||||||
|
- .github/workflows/test.yml
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- src/khoj/**
|
||||||
|
- tests/**
|
||||||
|
- config/**
|
||||||
|
- pyproject.toml
|
||||||
|
- .pre-commit-config.yml
|
||||||
|
- .github/workflows/test.yml
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: Run Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.11
|
||||||
|
|
||||||
|
- name: ⏬️ Install Dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt update && sudo apt install -y libegl1
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
|
||||||
|
- name: ⬇️ Install Application
|
||||||
|
run: pip install --upgrade .[dev]
|
||||||
|
|
||||||
|
- name: 🌡️ Validate Application
|
||||||
|
run: pre-commit run --hook-stage manual --all
|
50
.github/workflows/test.yml
vendored
50
.github/workflows/test.yml
vendored
|
@ -2,10 +2,8 @@ name: test
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
|
||||||
- 'master'
|
|
||||||
paths:
|
paths:
|
||||||
- src/khoj/**
|
- src/**
|
||||||
- tests/**
|
- tests/**
|
||||||
- config/**
|
- config/**
|
||||||
- pyproject.toml
|
- pyproject.toml
|
||||||
|
@ -13,7 +11,7 @@ on:
|
||||||
- .github/workflows/test.yml
|
- .github/workflows/test.yml
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'master'
|
- master
|
||||||
paths:
|
paths:
|
||||||
- src/khoj/**
|
- src/khoj/**
|
||||||
- tests/**
|
- tests/**
|
||||||
|
@ -26,6 +24,7 @@ jobs:
|
||||||
test:
|
test:
|
||||||
name: Run Tests
|
name: Run Tests
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
container: ubuntu:jammy
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -33,6 +32,17 @@ jobs:
|
||||||
- '3.9'
|
- '3.9'
|
||||||
- '3.10'
|
- '3.10'
|
||||||
- '3.11'
|
- '3.11'
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: ankane/pgvector
|
||||||
|
env:
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
|
@ -43,17 +53,37 @@ jobs:
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
|
|
||||||
- name: ⏬️ Install Dependencies
|
- name: Install Git
|
||||||
run: |
|
run: |
|
||||||
sudo apt update && sudo apt install -y libegl1
|
apt update && apt install -y git
|
||||||
|
|
||||||
|
- name: ⏬️ Install Dependencies
|
||||||
|
env:
|
||||||
|
DEBIAN_FRONTEND: noninteractive
|
||||||
|
run: |
|
||||||
|
apt update && apt install -y libegl1 sqlite3 libsqlite3-dev libsqlite3-0
|
||||||
|
|
||||||
|
- name: ⬇️ Install Postgres
|
||||||
|
env:
|
||||||
|
DEBIAN_FRONTEND: noninteractive
|
||||||
|
run : |
|
||||||
|
apt install -y postgresql postgresql-client && apt install -y postgresql-server-dev-14
|
||||||
|
|
||||||
|
- name: ⬇️ Install pip
|
||||||
|
run: |
|
||||||
|
apt install -y python3-pip
|
||||||
|
python -m ensurepip --upgrade
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|
||||||
- name: ⬇️ Install Application
|
- name: ⬇️ Install Application
|
||||||
run: pip install --upgrade .[dev]
|
run: sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && pip install --upgrade .[dev]
|
||||||
|
|
||||||
- name: 🌡️ Validate Application
|
|
||||||
run: pre-commit run --hook-stage manual --all
|
|
||||||
|
|
||||||
- name: 🧪 Test Application
|
- name: 🧪 Test Application
|
||||||
|
env:
|
||||||
|
POSTGRES_HOST: postgres
|
||||||
|
POSTGRES_PORT: 5432
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: postgres
|
||||||
run: pytest
|
run: pytest
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
|
|
11
Dockerfile
11
Dockerfile
|
@ -8,13 +8,20 @@ RUN apt update -y && apt -y install python3-pip git
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install Application
|
# Install Application
|
||||||
COPY . .
|
COPY pyproject.toml .
|
||||||
|
COPY README.md .
|
||||||
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
|
RUN sed -i 's/dynamic = \["version"\]/version = "0.0.0"/' pyproject.toml && \
|
||||||
pip install --no-cache-dir .
|
pip install --no-cache-dir .
|
||||||
|
|
||||||
|
# Copy Source Code
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Set the PYTHONPATH environment variable in order for it to find the Django app.
|
||||||
|
ENV PYTHONPATH=/app/src:$PYTHONPATH
|
||||||
|
|
||||||
# Run the Application
|
# Run the Application
|
||||||
# There are more arguments required for the application to run,
|
# There are more arguments required for the application to run,
|
||||||
# but these should be passed in through the docker-compose.yml file.
|
# but these should be passed in through the docker-compose.yml file.
|
||||||
ARG PORT
|
ARG PORT
|
||||||
EXPOSE ${PORT}
|
EXPOSE ${PORT}
|
||||||
ENTRYPOINT ["khoj"]
|
ENTRYPOINT ["python3", "src/khoj/main.py"]
|
||||||
|
|
|
@ -1,7 +1,21 @@
|
||||||
version: "3.9"
|
version: "3.9"
|
||||||
services:
|
services:
|
||||||
|
database:
|
||||||
|
image: ankane/pgvector
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: postgres
|
||||||
|
volumes:
|
||||||
|
- khoj_db:/var/lib/postgresql/data/
|
||||||
server:
|
server:
|
||||||
|
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
|
||||||
image: ghcr.io/khoj-ai/khoj:latest
|
image: ghcr.io/khoj-ai/khoj:latest
|
||||||
|
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the offiicial image.
|
||||||
|
# build:
|
||||||
|
# context: .
|
||||||
ports:
|
ports:
|
||||||
# If changing the local port (left hand side), no other changes required.
|
# If changing the local port (left hand side), no other changes required.
|
||||||
# If changing the remote port (right hand side),
|
# If changing the remote port (right hand side),
|
||||||
|
@ -26,8 +40,15 @@ services:
|
||||||
- ./tests/data/models/:/root/.khoj/search/
|
- ./tests/data/models/:/root/.khoj/search/
|
||||||
- khoj_config:/root/.khoj/
|
- khoj_config:/root/.khoj/
|
||||||
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
||||||
|
environment:
|
||||||
|
- POSTGRES_DB=postgres
|
||||||
|
- POSTGRES_USER=postgres
|
||||||
|
- POSTGRES_PASSWORD=postgres
|
||||||
|
- POSTGRES_HOST=database
|
||||||
|
- POSTGRES_PORT=5432
|
||||||
command: --host="0.0.0.0" --port=42110 -vv
|
command: --host="0.0.0.0" --port=42110 -vv
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
khoj_config:
|
khoj_config:
|
||||||
|
khoj_db:
|
||||||
|
|
|
@ -59,13 +59,15 @@ dependencies = [
|
||||||
"requests >= 2.26.0",
|
"requests >= 2.26.0",
|
||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
"anyio == 3.7.1",
|
"anyio == 3.7.1",
|
||||||
"pymupdf >= 1.23.3",
|
"pymupdf >= 1.23.5",
|
||||||
"django == 4.2.5",
|
"django == 4.2.5",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
|
"gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
"httpx == 0.25.0",
|
"httpx == 0.25.0",
|
||||||
|
"pgvector == 0.2.3",
|
||||||
|
"psycopg2-binary == 2.9.9",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
@ -91,6 +93,8 @@ dev = [
|
||||||
"mypy >= 1.0.1",
|
"mypy >= 1.0.1",
|
||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
|
"pytest-django == 4.5.2",
|
||||||
|
"pytest-asyncio == 0.21.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.version]
|
[tool.hatch.version]
|
||||||
|
|
4
pytest.ini
Normal file
4
pytest.ini
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
[pytest]
|
||||||
|
DJANGO_SETTINGS_MODULE = app.settings
|
||||||
|
pythonpath = . src
|
||||||
|
testpaths = tests
|
60
src/app/README.md
Normal file
60
src/app/README.md
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Django App
|
||||||
|
|
||||||
|
Khoj uses Django as the backend framework primarily for its powerful ORM and the admin interface. The Django app is located in the `src/app` directory. We have one installed app, under the `/database/` directory. This app is responsible for all the database related operations and holds all of our models. You can find the extensive Django documentation [here](https://docs.djangoproject.com/en/4.2/) 🌈.
|
||||||
|
|
||||||
|
## Setup (Docker)
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
1. Ensure you have [Docker](https://docs.docker.com/get-docker/) installed.
|
||||||
|
2. Ensure you have [Docker Compose](https://docs.docker.com/compose/install/) installed.
|
||||||
|
|
||||||
|
### Run
|
||||||
|
|
||||||
|
Using the `docker-compose.yml` file in the root directory, you can run the Khoj app using the following command:
|
||||||
|
```bash
|
||||||
|
docker-compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setup (Local)
|
||||||
|
|
||||||
|
### Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e '.[dev]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Setup the database
|
||||||
|
|
||||||
|
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
|
||||||
|
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /tmp
|
||||||
|
git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
|
||||||
|
cd pgvector
|
||||||
|
make
|
||||||
|
make install # may need sudo
|
||||||
|
```
|
||||||
|
3. Create a database
|
||||||
|
|
||||||
|
### Make migrations
|
||||||
|
|
||||||
|
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 src/manage.py makemigrations
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run migrations
|
||||||
|
|
||||||
|
This command will run any pending migrations in your application.
|
||||||
|
```bash
|
||||||
|
python3 src/manage.py migrate
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the server
|
||||||
|
|
||||||
|
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
|
||||||
|
```bash
|
||||||
|
python3 src/khoj/main.py
|
||||||
|
```
|
|
@ -77,8 +77,12 @@ WSGI_APPLICATION = "app.wsgi.application"
|
||||||
|
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.sqlite3",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
"NAME": BASE_DIR / "db.sqlite3",
|
"HOST": os.getenv("POSTGRES_HOST", "localhost"),
|
||||||
|
"PORT": os.getenv("POSTGRES_PORT", "5432"),
|
||||||
|
"USER": os.getenv("POSTGRES_USER", "postgres"),
|
||||||
|
"NAME": os.getenv("POSTGRES_DB", "khoj"),
|
||||||
|
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,30 @@
|
||||||
from typing import Type, TypeVar
|
from typing import Type, TypeVar, List
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.contrib.sessions.backends.db import SessionStore
|
from django.contrib.sessions.backends.db import SessionStore
|
||||||
|
from pgvector.django import CosineDistance
|
||||||
|
from django.db.models.manager import BaseManager
|
||||||
|
from django.db.models import Q
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
# Import sync_to_async from Django Channels
|
# Import sync_to_async from Django Channels
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from database.models import KhojUser, GoogleUser, NotionConfig
|
from database.models import (
|
||||||
|
KhojUser,
|
||||||
|
GoogleUser,
|
||||||
|
NotionConfig,
|
||||||
|
GithubConfig,
|
||||||
|
Embeddings,
|
||||||
|
GithubRepoConfig,
|
||||||
|
)
|
||||||
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||||
|
|
||||||
|
@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
||||||
|
|
||||||
async def create_google_user(token: dict) -> KhojUser:
|
async def create_google_user(token: dict) -> KhojUser:
|
||||||
user_info = token.get("userinfo")
|
user_info = token.get("userinfo")
|
||||||
user = await KhojUser.objects.acreate(
|
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
|
||||||
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
|
|
||||||
)
|
|
||||||
await user.asave()
|
await user.asave()
|
||||||
await GoogleUser.objects.acreate(
|
await GoogleUser.objects.acreate(
|
||||||
sub=user_info.get("sub"),
|
sub=user_info.get("sub"),
|
||||||
|
@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid user")
|
raise HTTPException(status_code=401, detail="Invalid user")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_users() -> BaseManager[KhojUser]:
|
||||||
|
return KhojUser.objects.all()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_github_config(user: KhojUser):
|
||||||
|
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_notion_config(user: KhojUser):
|
||||||
|
config = NotionConfig.objects.filter(user=user).first()
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
||||||
|
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
||||||
|
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
||||||
|
await object.objects.filter(user=user).adelete()
|
||||||
|
await object.objects.acreate(
|
||||||
|
input_files=deduped_files,
|
||||||
|
input_filter=deduped_filters,
|
||||||
|
index_heading_entries=updated_config.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
|
config = await GithubConfig.objects.filter(user=user).afirst()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
|
||||||
|
else:
|
||||||
|
config.pat_token = pat_token
|
||||||
|
await config.asave()
|
||||||
|
await config.githubrepoconfig.all().adelete()
|
||||||
|
|
||||||
|
for repo in repos:
|
||||||
|
await GithubRepoConfig.objects.acreate(
|
||||||
|
name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsAdapters:
|
||||||
|
word_filer = WordFilter()
|
||||||
|
file_filter = FileFilter()
|
||||||
|
date_filter = DateFilter()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
|
||||||
|
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_embedding_by_file(user: KhojUser, file_path: str):
|
||||||
|
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_all_embeddings(user: KhojUser, file_type: str):
|
||||||
|
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
|
||||||
|
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
|
||||||
|
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
|
||||||
|
return embeddings.filter(
|
||||||
|
embeddingsdates__date__gte=start_date,
|
||||||
|
embeddingsdates__date__lte=end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def user_has_embeddings(user: KhojUser):
|
||||||
|
return await Embeddings.objects.filter(user=user).aexists()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||||
|
q_filter_terms = Q()
|
||||||
|
|
||||||
|
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
|
||||||
|
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
|
||||||
|
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
|
||||||
|
|
||||||
|
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||||
|
return Embeddings.objects.filter(user=user)
|
||||||
|
|
||||||
|
for term in explicit_word_terms:
|
||||||
|
if term.startswith("+"):
|
||||||
|
q_filter_terms &= Q(raw__icontains=term[1:])
|
||||||
|
elif term.startswith("-"):
|
||||||
|
q_filter_terms &= ~Q(raw__icontains=term[1:])
|
||||||
|
|
||||||
|
q_file_filter_terms = Q()
|
||||||
|
|
||||||
|
if len(file_filters) > 0:
|
||||||
|
for term in file_filters:
|
||||||
|
q_file_filter_terms |= Q(file_path__regex=term)
|
||||||
|
|
||||||
|
q_filter_terms &= q_file_filter_terms
|
||||||
|
|
||||||
|
if len(date_filters) > 0:
|
||||||
|
min_date, max_date = date_filters
|
||||||
|
if min_date is not None:
|
||||||
|
# Convert the min_date timestamp to yyyy-mm-dd format
|
||||||
|
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
|
||||||
|
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
|
||||||
|
if max_date is not None:
|
||||||
|
# Convert the max_date timestamp to yyyy-mm-dd format
|
||||||
|
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
||||||
|
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
||||||
|
|
||||||
|
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
|
||||||
|
q_filter_terms,
|
||||||
|
)
|
||||||
|
if file_type_filter:
|
||||||
|
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||||
|
return relevant_embeddings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def search_with_embeddings(
|
||||||
|
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||||
|
):
|
||||||
|
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||||
|
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
|
||||||
|
distance=CosineDistance("embeddings", embeddings)
|
||||||
|
)
|
||||||
|
if file_type_filter:
|
||||||
|
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
|
||||||
|
relevant_embeddings = relevant_embeddings.order_by("distance")
|
||||||
|
return relevant_embeddings[:max_results]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_unique_file_types(user: KhojUser):
|
||||||
|
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
# Generated by Django 4.2.5 on 2023-09-27 17:52
|
|
||||||
|
|
||||||
from django.conf import settings
|
|
||||||
from django.db import migrations, models
|
|
||||||
import django.db.models.deletion
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("database", "0002_googleuser"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="Configuration",
|
|
||||||
fields=[
|
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="ConversationProcessorConfig",
|
|
||||||
fields=[
|
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
|
||||||
("conversation", models.JSONField()),
|
|
||||||
("enable_offline_chat", models.BooleanField(default=False)),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="GithubConfig",
|
|
||||||
fields=[
|
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
|
||||||
("pat_token", models.CharField(max_length=200)),
|
|
||||||
("compressed_jsonl", models.CharField(max_length=300)),
|
|
||||||
("embeddings_file", models.CharField(max_length=300)),
|
|
||||||
(
|
|
||||||
"config",
|
|
||||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="khojuser",
|
|
||||||
name="uuid",
|
|
||||||
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
|
|
||||||
preserve_default=False,
|
|
||||||
),
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="NotionConfig",
|
|
||||||
fields=[
|
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
|
||||||
("token", models.CharField(max_length=200)),
|
|
||||||
("compressed_jsonl", models.CharField(max_length=300)),
|
|
||||||
("embeddings_file", models.CharField(max_length=300)),
|
|
||||||
(
|
|
||||||
"config",
|
|
||||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="GithubRepoConfig",
|
|
||||||
fields=[
|
|
||||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
|
||||||
("name", models.CharField(max_length=200)),
|
|
||||||
("owner", models.CharField(max_length=200)),
|
|
||||||
("branch", models.CharField(max_length=200)),
|
|
||||||
(
|
|
||||||
"github_config",
|
|
||||||
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="configuration",
|
|
||||||
name="user",
|
|
||||||
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
|
||||||
),
|
|
||||||
]
|
|
10
src/database/migrations/0003_vector_extension.py
Normal file
10
src/database/migrations/0003_vector_extension.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
from django.db import migrations
|
||||||
|
from pgvector.django import VectorExtension
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0002_googleuser"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [VectorExtension()]
|
|
@ -0,0 +1,193 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-10-11 22:24
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import pgvector.django
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0003_vector_extension"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="ConversationProcessorConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("conversation", models.JSONField()),
|
||||||
|
("enable_offline_chat", models.BooleanField(default=False)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="GithubConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("pat_token", models.CharField(max_length=200)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="khojuser",
|
||||||
|
name="uuid",
|
||||||
|
field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
|
||||||
|
preserve_default=False,
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="NotionConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("token", models.CharField(max_length=200)),
|
||||||
|
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="LocalPlaintextConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("input_files", models.JSONField(default=list, null=True)),
|
||||||
|
("input_filter", models.JSONField(default=list, null=True)),
|
||||||
|
("index_heading_entries", models.BooleanField(default=False)),
|
||||||
|
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="LocalPdfConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("input_files", models.JSONField(default=list, null=True)),
|
||||||
|
("input_filter", models.JSONField(default=list, null=True)),
|
||||||
|
("index_heading_entries", models.BooleanField(default=False)),
|
||||||
|
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="LocalOrgConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("input_files", models.JSONField(default=list, null=True)),
|
||||||
|
("input_filter", models.JSONField(default=list, null=True)),
|
||||||
|
("index_heading_entries", models.BooleanField(default=False)),
|
||||||
|
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="LocalMarkdownConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("input_files", models.JSONField(default=list, null=True)),
|
||||||
|
("input_filter", models.JSONField(default=list, null=True)),
|
||||||
|
("index_heading_entries", models.BooleanField(default=False)),
|
||||||
|
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="GithubRepoConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("name", models.CharField(max_length=200)),
|
||||||
|
("owner", models.CharField(max_length=200)),
|
||||||
|
("branch", models.CharField(max_length=200)),
|
||||||
|
(
|
||||||
|
"github_config",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="githubrepoconfig",
|
||||||
|
to="database.githubconfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="githubconfig",
|
||||||
|
name="user",
|
||||||
|
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="Embeddings",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("embeddings", pgvector.django.VectorField(dimensions=384)),
|
||||||
|
("raw", models.TextField()),
|
||||||
|
("compiled", models.TextField()),
|
||||||
|
("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)),
|
||||||
|
(
|
||||||
|
"file_type",
|
||||||
|
models.CharField(
|
||||||
|
choices=[
|
||||||
|
("image", "Image"),
|
||||||
|
("pdf", "Pdf"),
|
||||||
|
("plaintext", "Plaintext"),
|
||||||
|
("markdown", "Markdown"),
|
||||||
|
("org", "Org"),
|
||||||
|
("notion", "Notion"),
|
||||||
|
("github", "Github"),
|
||||||
|
("conversation", "Conversation"),
|
||||||
|
],
|
||||||
|
default="plaintext",
|
||||||
|
max_length=30,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||||
|
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||||
|
("url", models.URLField(blank=True, default=None, max_length=400, null=True)),
|
||||||
|
("hashed_value", models.CharField(max_length=100)),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to=settings.AUTH_USER_MODEL,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
18
src/database/migrations/0005_embeddings_corpus_id.py
Normal file
18
src/database/migrations/0005_embeddings_corpus_id.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-10-13 02:39
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0004_conversationprocessorconfig_githubconfig_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="embeddings",
|
||||||
|
name="corpus_id",
|
||||||
|
field=models.UUIDField(default=uuid.uuid4, editable=False),
|
||||||
|
),
|
||||||
|
]
|
33
src/database/migrations/0006_embeddingsdates.py
Normal file
33
src/database/migrations/0006_embeddingsdates.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-10-13 19:28
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0005_embeddings_corpus_id"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="EmbeddingsDates",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("date", models.DateField()),
|
||||||
|
(
|
||||||
|
"embeddings",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="embeddings_dates",
|
||||||
|
to="database.embeddings",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
|
@ -2,11 +2,25 @@ import uuid
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.contrib.auth.models import AbstractUser
|
from django.contrib.auth.models import AbstractUser
|
||||||
|
from pgvector.django import VectorField
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(models.Model):
|
||||||
|
created_at = models.DateTimeField(auto_now_add=True)
|
||||||
|
updated_at = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
|
||||||
class KhojUser(AbstractUser):
|
class KhojUser(AbstractUser):
|
||||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
if not self.uuid:
|
||||||
|
self.uuid = uuid.uuid4()
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class GoogleUser(models.Model):
|
class GoogleUser(models.Model):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
@ -23,31 +37,85 @@ class GoogleUser(models.Model):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class Configuration(models.Model):
|
class NotionConfig(BaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
|
||||||
|
|
||||||
|
|
||||||
class NotionConfig(models.Model):
|
|
||||||
token = models.CharField(max_length=200)
|
token = models.CharField(max_length=200)
|
||||||
compressed_jsonl = models.CharField(max_length=300)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
embeddings_file = models.CharField(max_length=300)
|
|
||||||
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubConfig(models.Model):
|
class GithubConfig(BaseModel):
|
||||||
pat_token = models.CharField(max_length=200)
|
pat_token = models.CharField(max_length=200)
|
||||||
compressed_jsonl = models.CharField(max_length=300)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
embeddings_file = models.CharField(max_length=300)
|
|
||||||
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubRepoConfig(models.Model):
|
class GithubRepoConfig(BaseModel):
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
owner = models.CharField(max_length=200)
|
owner = models.CharField(max_length=200)
|
||||||
branch = models.CharField(max_length=200)
|
branch = models.CharField(max_length=200)
|
||||||
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
|
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfig(models.Model):
|
class LocalOrgConfig(BaseModel):
|
||||||
|
input_files = models.JSONField(default=list, null=True)
|
||||||
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalMarkdownConfig(BaseModel):
|
||||||
|
input_files = models.JSONField(default=list, null=True)
|
||||||
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalPdfConfig(BaseModel):
|
||||||
|
input_files = models.JSONField(default=list, null=True)
|
||||||
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalPlaintextConfig(BaseModel):
|
||||||
|
input_files = models.JSONField(default=list, null=True)
|
||||||
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationProcessorConfig(BaseModel):
|
||||||
conversation = models.JSONField()
|
conversation = models.JSONField()
|
||||||
enable_offline_chat = models.BooleanField(default=False)
|
enable_offline_chat = models.BooleanField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings(BaseModel):
|
||||||
|
class EmbeddingsType(models.TextChoices):
|
||||||
|
IMAGE = "image"
|
||||||
|
PDF = "pdf"
|
||||||
|
PLAINTEXT = "plaintext"
|
||||||
|
MARKDOWN = "markdown"
|
||||||
|
ORG = "org"
|
||||||
|
NOTION = "notion"
|
||||||
|
GITHUB = "github"
|
||||||
|
CONVERSATION = "conversation"
|
||||||
|
|
||||||
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
embeddings = VectorField(dimensions=384)
|
||||||
|
raw = models.TextField()
|
||||||
|
compiled = models.TextField()
|
||||||
|
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||||
|
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
|
||||||
|
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||||
|
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||||
|
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||||
|
hashed_value = models.CharField(max_length=100)
|
||||||
|
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsDates(BaseModel):
|
||||||
|
date = models.DateField()
|
||||||
|
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
indexes = [
|
||||||
|
models.Index(fields=["date"]),
|
||||||
|
]
|
||||||
|
|
3
src/database/tests.py
Normal file
3
src/database/tests.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
# Create your tests here.
|
|
@ -30,6 +30,8 @@ from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
||||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||||
|
from database.models import KhojUser
|
||||||
|
from database.adapters import get_all_users
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -48,14 +50,28 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser
|
||||||
|
|
||||||
self.khojuser_manager = KhojUser.objects
|
self.khojuser_manager = KhojUser.objects
|
||||||
|
self._initialize_default_user()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def _initialize_default_user(self):
|
||||||
|
if not self.khojuser_manager.filter(username="default").exists():
|
||||||
|
self.khojuser_manager.create_user(
|
||||||
|
username="default",
|
||||||
|
email="default@example.com",
|
||||||
|
password="default",
|
||||||
|
)
|
||||||
|
|
||||||
async def authenticate(self, request):
|
async def authenticate(self, request):
|
||||||
current_user = request.session.get("user")
|
current_user = request.session.get("user")
|
||||||
if current_user and current_user.get("email"):
|
if current_user and current_user.get("email"):
|
||||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||||
if user:
|
if user:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
elif not state.anonymous_mode:
|
||||||
|
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||||
|
if user:
|
||||||
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
|
||||||
return AuthCredentials(), UnauthenticatedUser()
|
return AuthCredentials(), UnauthenticatedUser()
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +94,11 @@ def initialize_server(config: Optional[FullConfig]):
|
||||||
|
|
||||||
|
|
||||||
def configure_server(
|
def configure_server(
|
||||||
config: FullConfig, regenerate: bool = False, search_type: Optional[SearchType] = None, init=False
|
config: FullConfig,
|
||||||
|
regenerate: bool = False,
|
||||||
|
search_type: Optional[SearchType] = None,
|
||||||
|
init=False,
|
||||||
|
user: KhojUser = None,
|
||||||
):
|
):
|
||||||
# Update Config
|
# Update Config
|
||||||
state.config = config
|
state.config = config
|
||||||
|
@ -95,7 +115,7 @@ def configure_server(
|
||||||
state.config_lock.acquire()
|
state.config_lock.acquire()
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
initialize_content(regenerate, search_type, init)
|
initialize_content(regenerate, search_type, init, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
@ -103,7 +123,7 @@ def configure_server(
|
||||||
state.config_lock.release()
|
state.config_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False):
|
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||||
# Initialize Content from Config
|
# Initialize Content from Config
|
||||||
if state.search_models:
|
if state.search_models:
|
||||||
try:
|
try:
|
||||||
|
@ -112,7 +132,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
||||||
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
|
state.content_index = load_content(state.config.content_type, state.content_index, state.search_models)
|
||||||
else:
|
else:
|
||||||
logger.info("📬 Updating content index...")
|
logger.info("📬 Updating content index...")
|
||||||
all_files = collect_files(state.config.content_type)
|
all_files = collect_files(user=user)
|
||||||
state.content_index = configure_content(
|
state.content_index = configure_content(
|
||||||
state.content_index,
|
state.content_index,
|
||||||
state.config.content_type,
|
state.config.content_type,
|
||||||
|
@ -120,6 +140,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
|
||||||
state.search_models,
|
state.search_models,
|
||||||
regenerate,
|
regenerate,
|
||||||
search_type,
|
search_type,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"🚨 Failed to index content", exc_info=True)
|
logger.error(f"🚨 Failed to index content", exc_info=True)
|
||||||
|
@ -152,9 +173,14 @@ if not state.demo:
|
||||||
def update_search_index():
|
def update_search_index():
|
||||||
try:
|
try:
|
||||||
logger.info("📬 Updating content index via Scheduler")
|
logger.info("📬 Updating content index via Scheduler")
|
||||||
all_files = collect_files(state.config.content_type)
|
for user in get_all_users():
|
||||||
|
all_files = collect_files(user=user)
|
||||||
|
state.content_index = configure_content(
|
||||||
|
state.content_index, state.config.content_type, all_files, state.search_models, user=user
|
||||||
|
)
|
||||||
|
all_files = collect_files(user=None)
|
||||||
state.content_index = configure_content(
|
state.content_index = configure_content(
|
||||||
state.content_index, state.config.content_type, all_files, state.search_models
|
state.content_index, state.config.content_type, all_files, state.search_models, user=None
|
||||||
)
|
)
|
||||||
logger.info("📪 Content index updated via Scheduler")
|
logger.info("📪 Content index updated via Scheduler")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -164,13 +190,9 @@ if not state.demo:
|
||||||
def configure_search_types(config: FullConfig):
|
def configure_search_types(config: FullConfig):
|
||||||
# Extract core search types
|
# Extract core search types
|
||||||
core_search_types = {e.name: e.value for e in SearchType}
|
core_search_types = {e.name: e.value for e in SearchType}
|
||||||
# Extract configured plugin search types
|
|
||||||
plugin_search_types = {}
|
|
||||||
if config.content_type and config.content_type.plugins:
|
|
||||||
plugin_search_types = {plugin_type: plugin_type for plugin_type in config.content_type.plugins.keys()}
|
|
||||||
|
|
||||||
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
# Dynamically generate search type enum by merging core search types with configured plugin search types
|
||||||
return Enum("SearchType", merge_dicts(core_search_types, plugin_search_types))
|
return Enum("SearchType", merge_dicts(core_search_types, {}))
|
||||||
|
|
||||||
|
|
||||||
def configure_processor(
|
def configure_processor(
|
||||||
|
|
|
@ -10,12 +10,10 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
|
<img class="card-icon" src="/static/assets/icons/github.svg" alt="Github">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Github
|
Github
|
||||||
{% if current_config.content_type.github %}
|
{% if current_model_state.github == False %}
|
||||||
{% if current_model_state.github == False %}
|
<img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||||
<img id="misconfigured-icon-github" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
{% else %}
|
||||||
{% else %}
|
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
<img id="configured-icon-github" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
|
||||||
{% endif %}
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</h3>
|
</h3>
|
||||||
</div>
|
</div>
|
||||||
|
@ -24,7 +22,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/github">
|
<a class="card-button" href="/config/content_type/github">
|
||||||
{% if current_config.content_type.github %}
|
{% if current_model_state.github %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -32,7 +30,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.github %}
|
{% if current_model_state.github %}
|
||||||
<div id="clear-github" class="card-action-row">
|
<div id="clear-github" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('github')">
|
<button class="card-button" onclick="clearContentType('github')">
|
||||||
Disable
|
Disable
|
||||||
|
@ -45,12 +43,10 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
<img class="card-icon" src="/static/assets/icons/notion.svg" alt="Notion">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Notion
|
Notion
|
||||||
{% if current_config.content_type.notion %}
|
{% if current_model_state.notion == False %}
|
||||||
{% if current_model_state.notion == False %}
|
<img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||||
<img id="misconfigured-icon-notion" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
{% else %}
|
||||||
{% else %}
|
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||||
<img id="configured-icon-notion" class="configured-icon" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
|
||||||
{% endif %}
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</h3>
|
</h3>
|
||||||
</div>
|
</div>
|
||||||
|
@ -59,7 +55,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/notion">
|
<a class="card-button" href="/config/content_type/notion">
|
||||||
{% if current_config.content_type.content %}
|
{% if current_model_state.content %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -67,7 +63,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.notion %}
|
{% if current_model_state.notion %}
|
||||||
<div id="clear-notion" class="card-action-row">
|
<div id="clear-notion" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('notion')">
|
<button class="card-button" onclick="clearContentType('notion')">
|
||||||
Disable
|
Disable
|
||||||
|
@ -80,7 +76,7 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
|
<img class="card-icon" src="/static/assets/icons/markdown.svg" alt="markdown">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Markdown
|
Markdown
|
||||||
{% if current_config.content_type.markdown %}
|
{% if current_model_state.markdown %}
|
||||||
{% if current_model_state.markdown == False%}
|
{% if current_model_state.markdown == False%}
|
||||||
<img id="misconfigured-icon-markdown" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
<img id="misconfigured-icon-markdown" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||||
{% else %}
|
{% else %}
|
||||||
|
@ -94,7 +90,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/markdown">
|
<a class="card-button" href="/config/content_type/markdown">
|
||||||
{% if current_config.content_type.markdown %}
|
{% if current_model_state.markdown %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -102,7 +98,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.markdown %}
|
{% if current_model_state.markdown %}
|
||||||
<div id="clear-markdown" class="card-action-row">
|
<div id="clear-markdown" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('markdown')">
|
<button class="card-button" onclick="clearContentType('markdown')">
|
||||||
Disable
|
Disable
|
||||||
|
@ -115,7 +111,7 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/org.svg" alt="org">
|
<img class="card-icon" src="/static/assets/icons/org.svg" alt="org">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Org
|
Org
|
||||||
{% if current_config.content_type.org %}
|
{% if current_model_state.org %}
|
||||||
{% if current_model_state.org == False %}
|
{% if current_model_state.org == False %}
|
||||||
<img id="misconfigured-icon-org" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
<img id="misconfigured-icon-org" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||||
{% else %}
|
{% else %}
|
||||||
|
@ -129,7 +125,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/org">
|
<a class="card-button" href="/config/content_type/org">
|
||||||
{% if current_config.content_type.org %}
|
{% if current_model_state.org %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -137,7 +133,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.org %}
|
{% if current_model_state.org %}
|
||||||
<div id="clear-org" class="card-action-row">
|
<div id="clear-org" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('org')">
|
<button class="card-button" onclick="clearContentType('org')">
|
||||||
Disable
|
Disable
|
||||||
|
@ -150,7 +146,7 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/pdf.svg" alt="PDF">
|
<img class="card-icon" src="/static/assets/icons/pdf.svg" alt="PDF">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
PDF
|
PDF
|
||||||
{% if current_config.content_type.pdf %}
|
{% if current_model_state.pdf %}
|
||||||
{% if current_model_state.pdf == False %}
|
{% if current_model_state.pdf == False %}
|
||||||
<img id="misconfigured-icon-pdf" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
<img id="misconfigured-icon-pdf" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
||||||
{% else %}
|
{% else %}
|
||||||
|
@ -164,7 +160,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/pdf">
|
<a class="card-button" href="/config/content_type/pdf">
|
||||||
{% if current_config.content_type.pdf %}
|
{% if current_model_state.pdf %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -172,7 +168,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.pdf %}
|
{% if current_model_state.pdf %}
|
||||||
<div id="clear-pdf" class="card-action-row">
|
<div id="clear-pdf" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('pdf')">
|
<button class="card-button" onclick="clearContentType('pdf')">
|
||||||
Disable
|
Disable
|
||||||
|
@ -185,7 +181,7 @@
|
||||||
<img class="card-icon" src="/static/assets/icons/plaintext.svg" alt="Plaintext">
|
<img class="card-icon" src="/static/assets/icons/plaintext.svg" alt="Plaintext">
|
||||||
<h3 class="card-title">
|
<h3 class="card-title">
|
||||||
Plaintext
|
Plaintext
|
||||||
{% if current_config.content_type.plaintext %}
|
{% if current_model_state.plaintext %}
|
||||||
{% if current_model_state.plaintext == False %}
|
{% if current_model_state.plaintext == False %}
|
||||||
<img id="misconfigured-icon-plaintext" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
<img id="misconfigured-icon-plaintext" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you need to click Configure.">
|
||||||
{% else %}
|
{% else %}
|
||||||
|
@ -199,7 +195,7 @@
|
||||||
</div>
|
</div>
|
||||||
<div class="card-action-row">
|
<div class="card-action-row">
|
||||||
<a class="card-button" href="/config/content_type/plaintext">
|
<a class="card-button" href="/config/content_type/plaintext">
|
||||||
{% if current_config.content_type.plaintext %}
|
{% if current_model_state.plaintext %}
|
||||||
Update
|
Update
|
||||||
{% else %}
|
{% else %}
|
||||||
Setup
|
Setup
|
||||||
|
@ -207,7 +203,7 @@
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
{% if current_config.content_type.plaintext %}
|
{% if current_model_state.plaintext %}
|
||||||
<div id="clear-plaintext" class="card-action-row">
|
<div id="clear-plaintext" class="card-action-row">
|
||||||
<button class="card-button" onclick="clearContentType('plaintext')">
|
<button class="card-button" onclick="clearContentType('plaintext')">
|
||||||
Disable
|
Disable
|
||||||
|
|
|
@ -38,24 +38,6 @@
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</div>
|
</div>
|
||||||
<button type="button" id="add-repository-button">Add Repository</button>
|
<button type="button" id="add-repository-button">Add Repository</button>
|
||||||
<table style="display: none;" >
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
<div class="section">
|
<div class="section">
|
||||||
<div id="success" style="display: none;"></div>
|
<div id="success" style="display: none;"></div>
|
||||||
<button id="submit" type="submit">Save</button>
|
<button id="submit" type="submit">Save</button>
|
||||||
|
@ -107,8 +89,6 @@
|
||||||
submit.addEventListener("click", function(event) {
|
submit.addEventListener("click", function(event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
|
||||||
const embeddings_file = document.getElementById("embeddings-file").value;
|
|
||||||
const pat_token = document.getElementById("pat-token").value;
|
const pat_token = document.getElementById("pat-token").value;
|
||||||
|
|
||||||
if (pat_token == "") {
|
if (pat_token == "") {
|
||||||
|
@ -154,8 +134,6 @@
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
"pat_token": pat_token,
|
"pat_token": pat_token,
|
||||||
"repos": repos,
|
"repos": repos,
|
||||||
"compressed_jsonl": compressed_jsonl,
|
|
||||||
"embeddings_file": embeddings_file,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
|
|
@ -43,33 +43,6 @@
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
<table style="display: none;" >
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="index-heading-entries">Index Heading Entries</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="index-heading-entries" name="index-heading-entries" value="{{ current_config['index_heading_entries'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
<div class="section">
|
<div class="section">
|
||||||
<div id="success" style="display: none;" ></div>
|
<div id="success" style="display: none;" ></div>
|
||||||
<button id="submit" type="submit">Save</button>
|
<button id="submit" type="submit">Save</button>
|
||||||
|
@ -155,9 +128,8 @@
|
||||||
inputFilter = null;
|
inputFilter = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
var compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
// var index_heading_entries = document.getElementById("index-heading-entries").value;
|
||||||
var embeddings_file = document.getElementById("embeddings-file").value;
|
var index_heading_entries = true;
|
||||||
var index_heading_entries = document.getElementById("index-heading-entries").value;
|
|
||||||
|
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/config/data/content_type/{{ content_type }}', {
|
fetch('/api/config/data/content_type/{{ content_type }}', {
|
||||||
|
@ -169,8 +141,6 @@
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
"input_files": inputFiles,
|
"input_files": inputFiles,
|
||||||
"input_filter": inputFilter,
|
"input_filter": inputFilter,
|
||||||
"compressed_jsonl": compressed_jsonl,
|
|
||||||
"embeddings_file": embeddings_file,
|
|
||||||
"index_heading_entries": index_heading_entries
|
"index_heading_entries": index_heading_entries
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -20,24 +20,6 @@
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
<table style="display: none;" >
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="compressed-jsonl">Compressed JSONL (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="compressed-jsonl" name="compressed-jsonl" value="{{ current_config['compressed_jsonl'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="embeddings-file">Embeddings File (Output)</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="embeddings-file" name="embeddings-file" value="{{ current_config['embeddings_file'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
<div class="section">
|
<div class="section">
|
||||||
<div id="success" style="display: none;"></div>
|
<div id="success" style="display: none;"></div>
|
||||||
<button id="submit" type="submit">Save</button>
|
<button id="submit" type="submit">Save</button>
|
||||||
|
@ -51,8 +33,6 @@
|
||||||
submit.addEventListener("click", function(event) {
|
submit.addEventListener("click", function(event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
|
||||||
const embeddings_file = document.getElementById("embeddings-file").value;
|
|
||||||
const token = document.getElementById("token").value;
|
const token = document.getElementById("token").value;
|
||||||
|
|
||||||
if (token == "") {
|
if (token == "") {
|
||||||
|
@ -70,8 +50,6 @@
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
"token": token,
|
"token": token,
|
||||||
"compressed_jsonl": compressed_jsonl,
|
|
||||||
"embeddings_file": embeddings_file,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
|
|
@ -172,7 +172,7 @@
|
||||||
url = createRequestUrl(query, type, results_count || 5, rerank);
|
url = createRequestUrl(query, type, results_count || 5, rerank);
|
||||||
fetch(url, {
|
fetch(url, {
|
||||||
headers: {
|
headers: {
|
||||||
"X-CSRFToken": csrfToken
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
@ -199,8 +199,8 @@
|
||||||
fetch("/api/config/types")
|
fetch("/api/config/types")
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(enabled_types => {
|
.then(enabled_types => {
|
||||||
// Show warning if no content types are enabled
|
// Show warning if no content types are enabled, or just one ("all")
|
||||||
if (enabled_types.detail) {
|
if (enabled_types[0] === "all" && enabled_types.length === 1) {
|
||||||
document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>";
|
document.getElementById("results").innerHTML = "<div id='results-error'>To use Khoj search, setup your content plugins on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>.</div>";
|
||||||
document.getElementById("query").setAttribute("disabled", "disabled");
|
document.getElementById("query").setAttribute("disabled", "disabled");
|
||||||
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");
|
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");
|
||||||
|
|
|
@ -24,10 +24,15 @@ from rich.logging import RichHandler
|
||||||
from django.core.asgi import get_asgi_application
|
from django.core.asgi import get_asgi_application
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
|
|
||||||
# Internal Packages
|
# Initialize Django
|
||||||
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||||
from khoj.utils import state
|
django.setup()
|
||||||
from khoj.utils.cli import cli
|
|
||||||
|
# Initialize Django Database
|
||||||
|
call_command("migrate", "--noinput")
|
||||||
|
|
||||||
|
# Initialize Django Static Files
|
||||||
|
call_command("collectstatic", "--noinput")
|
||||||
|
|
||||||
# Initialize Django
|
# Initialize Django
|
||||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
|
||||||
|
@ -54,6 +59,11 @@ app.add_middleware(
|
||||||
# Set Locale
|
# Set Locale
|
||||||
locale.setlocale(locale.LC_ALL, "")
|
locale.setlocale(locale.LC_ALL, "")
|
||||||
|
|
||||||
|
# Internal Packages. We do this after setting up Django so that Django features are accessible to the app.
|
||||||
|
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.cli import cli
|
||||||
|
|
||||||
# Setup Logger
|
# Setup Logger
|
||||||
rich_handler = RichHandler(rich_tracebacks=True)
|
rich_handler = RichHandler(rich_tracebacks=True)
|
||||||
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
|
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
|
||||||
|
@ -95,6 +105,8 @@ def run():
|
||||||
|
|
||||||
# Mount Django and Static Files
|
# Mount Django and Static Files
|
||||||
app.mount("/django", django_app, name="django")
|
app.mount("/django", django_app, name="django")
|
||||||
|
if not os.path.exists("static"):
|
||||||
|
os.mkdir("static")
|
||||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||||
|
|
||||||
# Configure Middleware
|
# Configure Middleware
|
||||||
|
@ -111,6 +123,7 @@ def set_state(args):
|
||||||
state.host = args.host
|
state.host = args.host
|
||||||
state.port = args.port
|
state.port = args.port
|
||||||
state.demo = args.demo
|
state.demo = args.demo
|
||||||
|
state.anonymous_mode = args.anonymous_mode
|
||||||
state.khoj_version = version("khoj-assistant")
|
state.khoj_version = version("khoj-assistant")
|
||||||
|
|
||||||
|
|
57
src/khoj/processor/embeddings.py
Normal file
57
src/khoj/processor/embeddings.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
|
||||||
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
|
||||||
|
encode_kwargs = {"normalize_embeddings": True}
|
||||||
|
# encode_kwargs = {}
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
# Use CUDA GPU
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# Use Apple M1 Metal Acceleration
|
||||||
|
device = torch.device("mps")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
model_kwargs = {"device": device}
|
||||||
|
self.embeddings_model = HuggingFaceEmbeddings(
|
||||||
|
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_query(self, query):
|
||||||
|
return self.embeddings_model.embed_query(query)
|
||||||
|
|
||||||
|
def embed_documents(self, docs):
|
||||||
|
return self.embeddings_model.embed_documents(docs)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEncoderModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
# Use CUDA GPU
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
# Use Apple M1 Metal Acceleration
|
||||||
|
device = torch.device("mps")
|
||||||
|
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
|
||||||
|
|
||||||
|
def predict(self, query, hits: List[SearchResponse]):
|
||||||
|
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||||
|
cross_scores = self.cross_encoder_model.predict(cross__inp)
|
||||||
|
return cross_scores
|
|
@ -2,7 +2,7 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union, Tuple
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import requests
|
import requests
|
||||||
|
@ -12,18 +12,31 @@ from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
|
from database.models import Embeddings, GithubConfig, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GithubToJsonl(TextToJsonl):
|
class GithubToJsonl(TextEmbeddings):
|
||||||
def __init__(self, config: GithubContentConfig):
|
def __init__(self, config: GithubConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
raw_repos = config.githubrepoconfig.all()
|
||||||
|
repos = []
|
||||||
|
for repo in raw_repos:
|
||||||
|
repos.append(
|
||||||
|
GithubRepoConfig(
|
||||||
|
name=repo.name,
|
||||||
|
owner=repo.owner,
|
||||||
|
branch=repo.branch,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.config = GithubContentConfig(
|
||||||
|
pat_token=config.pat_token,
|
||||||
|
repos=repos,
|
||||||
|
)
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
|
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
|
||||||
|
|
||||||
|
@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def process(self, previous_entries=[], files=None, full_corpus=True):
|
def process(
|
||||||
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
|
) -> Tuple[int, int]:
|
||||||
if self.config.pat_token is None or self.config.pat_token == "":
|
if self.config.pat_token is None or self.config.pat_token == "":
|
||||||
logger.error(f"Github PAT token is not set. Skipping github content")
|
logger.error(f"Github PAT token is not set. Skipping github content")
|
||||||
raise ValueError("Github PAT token is not set. Skipping github content")
|
raise ValueError("Github PAT token is not set. Skipping github content")
|
||||||
|
@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl):
|
||||||
for repo in self.config.repos:
|
for repo in self.config.repos:
|
||||||
current_entries += self.process_repo(repo)
|
current_entries += self.process_repo(repo)
|
||||||
|
|
||||||
return self.update_entries_with_ids(current_entries, previous_entries)
|
return self.update_entries_with_ids(current_entries, user=user)
|
||||||
|
|
||||||
def process_repo(self, repo: GithubRepoConfig):
|
def process_repo(self, repo: GithubRepoConfig):
|
||||||
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
||||||
|
@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl):
|
||||||
current_entries += issue_entries
|
current_entries += issue_entries
|
||||||
|
|
||||||
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
||||||
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
|
||||||
return current_entries
|
return current_entries
|
||||||
|
|
||||||
def update_entries_with_ids(self, current_entries, previous_entries):
|
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger
|
current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Write github entries to JSONL file", logger):
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
def get_files(self, repo_url: str, repo: GithubRepoConfig):
|
def get_files(self, repo_url: str, repo: GithubRepoConfig):
|
||||||
# Get the contents of the repository
|
# Get the contents of the repository
|
||||||
|
|
|
@ -1,91 +0,0 @@
|
||||||
# Standard Packages
|
|
||||||
import glob
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# Internal Packages
|
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
|
||||||
from khoj.utils.helpers import get_absolute_path, timer
|
|
||||||
from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
|
|
||||||
from khoj.utils.rawconfig import Entry
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class JsonlToJsonl(TextToJsonl):
|
|
||||||
# Define Functions
|
|
||||||
def process(self, previous_entries=[], files: dict[str, str] = {}, full_corpus: bool = True):
|
|
||||||
# Extract required fields from config
|
|
||||||
input_jsonl_files, input_jsonl_filter, output_file = (
|
|
||||||
self.config.input_files,
|
|
||||||
self.config.input_filter,
|
|
||||||
self.config.compressed_jsonl,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get Jsonl Input Files to Process
|
|
||||||
all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter)
|
|
||||||
|
|
||||||
# Extract Entries from specified jsonl files
|
|
||||||
with timer("Parse entries from jsonl files", logger):
|
|
||||||
input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files)
|
|
||||||
current_entries = list(map(Entry.from_dict, input_jsons))
|
|
||||||
|
|
||||||
# Split entries by max tokens supported by model
|
|
||||||
with timer("Split entries by max token size supported by model", logger):
|
|
||||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
|
||||||
with timer("Identify new or updated entries", logger):
|
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
|
||||||
current_entries, previous_entries, key="compiled", logger=logger
|
|
||||||
)
|
|
||||||
|
|
||||||
with timer("Write entries to JSONL file", logger):
|
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None):
|
|
||||||
"Get all jsonl files to process"
|
|
||||||
absolute_jsonl_files, filtered_jsonl_files = set(), set()
|
|
||||||
if jsonl_files:
|
|
||||||
absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files}
|
|
||||||
if jsonl_file_filters:
|
|
||||||
filtered_jsonl_files = {
|
|
||||||
filtered_file
|
|
||||||
for jsonl_file_filter in jsonl_file_filters
|
|
||||||
for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
|
|
||||||
}
|
|
||||||
|
|
||||||
all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files)
|
|
||||||
|
|
||||||
files_with_non_jsonl_extensions = {
|
|
||||||
jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl")
|
|
||||||
}
|
|
||||||
if any(files_with_non_jsonl_extensions):
|
|
||||||
print(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}")
|
|
||||||
|
|
||||||
logger.debug(f"Processing files: {all_jsonl_files}")
|
|
||||||
|
|
||||||
return all_jsonl_files
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def extract_jsonl_entries(jsonl_files):
|
|
||||||
"Extract entries from specified jsonl files"
|
|
||||||
entries = []
|
|
||||||
for jsonl_file in jsonl_files:
|
|
||||||
entries.extend(load_jsonl(Path(jsonl_file)))
|
|
||||||
return entries
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_entries_to_jsonl(entries: List[Entry]):
|
|
||||||
"Convert each entry to JSON and collate as JSONL"
|
|
||||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
|
|
@ -3,29 +3,28 @@ import logging
|
||||||
import re
|
import re
|
||||||
import urllib3
|
import urllib3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Tuple, List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.constants import empty_escape_sequences
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
from database.models import Embeddings, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MarkdownToJsonl(TextToJsonl):
|
class MarkdownToJsonl(TextEmbeddings):
|
||||||
def __init__(self, config: TextContentConfig):
|
def __init__(self):
|
||||||
super().__init__(config)
|
super().__init__()
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries=[], files=None, full_corpus: bool = True):
|
def process(
|
||||||
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
|
) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
output_file = self.config.compressed_jsonl
|
|
||||||
|
|
||||||
if not full_corpus:
|
if not full_corpus:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
current_entries,
|
||||||
|
Embeddings.EmbeddingsType.MARKDOWN,
|
||||||
|
"compiled",
|
||||||
|
logger,
|
||||||
|
deletion_file_names,
|
||||||
|
user,
|
||||||
|
regenerate=regenerate,
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Write markdown entries to JSONL file", logger):
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_markdown_entries(markdown_files):
|
def extract_markdown_entries(markdown_files):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import requests
|
import requests
|
||||||
|
@ -7,9 +8,9 @@ import requests
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
from khoj.utils.rawconfig import Entry, NotionContentConfig
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
|
from database.models import Embeddings, KhojUser, NotionConfig
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
@ -49,10 +50,12 @@ class NotionBlockType(Enum):
|
||||||
CALLOUT = "callout"
|
CALLOUT = "callout"
|
||||||
|
|
||||||
|
|
||||||
class NotionToJsonl(TextToJsonl):
|
class NotionToJsonl(TextEmbeddings):
|
||||||
def __init__(self, config: NotionContentConfig):
|
def __init__(self, config: NotionConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = NotionContentConfig(
|
||||||
|
token=config.token,
|
||||||
|
)
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
|
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
|
||||||
self.unsupported_block_types = [
|
self.unsupported_block_types = [
|
||||||
|
@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl):
|
||||||
|
|
||||||
self.body_params = {"page_size": 100}
|
self.body_params = {"page_size": 100}
|
||||||
|
|
||||||
def process(self, previous_entries=[], files=None, full_corpus=True):
|
def process(
|
||||||
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
|
) -> Tuple[int, int]:
|
||||||
current_entries = []
|
current_entries = []
|
||||||
|
|
||||||
# Get all pages
|
# Get all pages
|
||||||
|
@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl):
|
||||||
page_entries = self.process_page(p_or_d)
|
page_entries = self.process_page(p_or_d)
|
||||||
current_entries.extend(page_entries)
|
current_entries.extend(page_entries)
|
||||||
|
|
||||||
return self.update_entries_with_ids(current_entries, previous_entries)
|
return self.update_entries_with_ids(current_entries, user)
|
||||||
|
|
||||||
def process_page(self, page):
|
def process_page(self, page):
|
||||||
page_id = page["id"]
|
page_id = page["id"]
|
||||||
|
@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl):
|
||||||
title = None
|
title = None
|
||||||
return title, content
|
return title, content
|
||||||
|
|
||||||
def update_entries_with_ids(self, current_entries, previous_entries):
|
def update_entries_with_ids(self, current_entries, user: KhojUser = None):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger
|
current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Write Notion entries to JSONL file", logger):
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
# Process Each Entry from all Notion entries
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
|
@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.processor.org_mode import orgnode
|
from khoj.processor.org_mode import orgnode
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
from database.models import Embeddings, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OrgToJsonl(TextToJsonl):
|
class OrgToJsonl(TextEmbeddings):
|
||||||
def __init__(self, config: TextContentConfig):
|
def __init__(self):
|
||||||
super().__init__(config)
|
super().__init__()
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(
|
||||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
) -> List[Tuple[int, Entry]]:
|
) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
output_file = self.config.compressed_jsonl
|
index_heading_entries = True
|
||||||
index_heading_entries = self.config.index_heading_entries
|
|
||||||
|
|
||||||
if not full_corpus:
|
if not full_corpus:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
|
@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
current_entries,
|
||||||
|
Embeddings.EmbeddingsType.ORG,
|
||||||
|
"compiled",
|
||||||
|
logger,
|
||||||
|
deletion_file_names,
|
||||||
|
user,
|
||||||
|
regenerate=regenerate,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process Each Entry from All Notes Files
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
with timer("Write org entries to JSONL file", logger):
|
|
||||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
|
||||||
jsonl_data = self.convert_org_entries_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_org_entries(org_files: dict[str, str]):
|
def extract_org_entries(org_files: dict[str, str]):
|
||||||
|
|
|
@ -1,28 +1,31 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from langchain.document_loaders import PyMuPDFLoader
|
from langchain.document_loaders import PyMuPDFLoader
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
|
from database.models import Embeddings, KhojUser
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PdfToJsonl(TextToJsonl):
|
class PdfToJsonl(TextEmbeddings):
|
||||||
# Define Functions
|
def __init__(self):
|
||||||
def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True):
|
super().__init__()
|
||||||
# Extract required fields from config
|
|
||||||
output_file = self.config.compressed_jsonl
|
|
||||||
|
|
||||||
|
# Define Functions
|
||||||
|
def process(
|
||||||
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
# Extract required fields from config
|
||||||
if not full_corpus:
|
if not full_corpus:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
current_entries,
|
||||||
|
Embeddings.EmbeddingsType.PDF,
|
||||||
|
"compiled",
|
||||||
|
logger,
|
||||||
|
deletion_file_names,
|
||||||
|
user,
|
||||||
|
regenerate=regenerate,
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Write PDF entries to JSONL file", logger):
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(jsonl_data, output_file)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_pdf_entries(pdf_files):
|
def extract_pdf_entries(pdf_files):
|
||||||
|
@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl):
|
||||||
entry_to_location_map = []
|
entry_to_location_map = []
|
||||||
for pdf_file in pdf_files:
|
for pdf_file in pdf_files:
|
||||||
try:
|
try:
|
||||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
|
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
|
||||||
tmp_file = f"tmp_pdf_file.pdf"
|
tmp_file = f"tmp_pdf_file.pdf"
|
||||||
with open(f"{tmp_file}", "wb") as f:
|
with open(f"{tmp_file}", "wb") as f:
|
||||||
bytes = pdf_files[pdf_file]
|
bytes = pdf_files[pdf_file]
|
||||||
|
|
|
@ -4,22 +4,23 @@ from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.jsonl import compress_jsonl_data
|
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||||
from khoj.utils.rawconfig import Entry
|
from database.models import Embeddings, KhojUser, LocalPlaintextConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PlaintextToJsonl(TextToJsonl):
|
class PlaintextToJsonl(TextEmbeddings):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(
|
def process(
|
||||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
) -> List[Tuple[int, Entry]]:
|
) -> Tuple[int, int]:
|
||||||
output_file = self.config.compressed_jsonl
|
|
||||||
|
|
||||||
if not full_corpus:
|
if not full_corpus:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
files_to_process = set(files) - deletion_file_names
|
files_to_process = set(files) - deletion_file_names
|
||||||
|
@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl):
|
||||||
|
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
entries_with_ids = TextToJsonl.mark_entries_for_update(
|
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
|
||||||
current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
|
current_entries,
|
||||||
|
Embeddings.EmbeddingsType.PLAINTEXT,
|
||||||
|
key="compiled",
|
||||||
|
logger=logger,
|
||||||
|
deletion_filenames=deletion_file_names,
|
||||||
|
user=user,
|
||||||
|
regenerate=regenerate,
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Write entries to JSONL file", logger):
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
|
||||||
plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Compress JSONL formatted Data
|
|
||||||
compress_jsonl_data(plaintext_data, output_file)
|
|
||||||
|
|
||||||
return entries_with_ids
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||||
|
|
|
@ -2,24 +2,33 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, List, Tuple, Set
|
import uuid
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Callable, List, Tuple, Set, Any
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.rawconfig import Entry, TextConfigBase
|
from khoj.utils.rawconfig import Entry
|
||||||
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
from database.models import KhojUser, Embeddings, EmbeddingsDates
|
||||||
|
from database.adapters import EmbeddingsAdapters
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TextToJsonl(ABC):
|
class TextEmbeddings(ABC):
|
||||||
def __init__(self, config: TextConfigBase):
|
def __init__(self, config: Any = None):
|
||||||
|
self.embeddings_model = EmbeddingsModel()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(
|
def process(
|
||||||
self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
) -> List[Tuple[int, Entry]]:
|
) -> Tuple[int, int]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -38,6 +47,7 @@ class TextToJsonl(ABC):
|
||||||
|
|
||||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||||
|
corpus_id = uuid.uuid4()
|
||||||
|
|
||||||
# Split entry into chunks of max tokens
|
# Split entry into chunks of max tokens
|
||||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||||
|
@ -57,11 +67,103 @@ class TextToJsonl(ABC):
|
||||||
raw=entry.raw,
|
raw=entry.raw,
|
||||||
heading=entry.heading,
|
heading=entry.heading,
|
||||||
file=entry.file,
|
file=entry.file,
|
||||||
|
corpus_id=corpus_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return chunked_entries
|
return chunked_entries
|
||||||
|
|
||||||
|
def update_embeddings(
|
||||||
|
self,
|
||||||
|
current_entries: List[Entry],
|
||||||
|
file_type: str,
|
||||||
|
key="compiled",
|
||||||
|
logger: logging.Logger = None,
|
||||||
|
deletion_filenames: Set[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
regenerate: bool = False,
|
||||||
|
):
|
||||||
|
with timer("Construct current entry hashes", logger):
|
||||||
|
hashes_by_file = dict[str, set[str]]()
|
||||||
|
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
|
||||||
|
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||||
|
for entry in tqdm(current_entries, desc="Hashing Entries"):
|
||||||
|
hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry))
|
||||||
|
|
||||||
|
num_deleted_embeddings = 0
|
||||||
|
with timer("Preparing dataset for regeneration", logger):
|
||||||
|
if regenerate:
|
||||||
|
logger.info(f"Deleting all embeddings for file type {file_type}")
|
||||||
|
num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
|
||||||
|
|
||||||
|
num_new_embeddings = 0
|
||||||
|
with timer("Identify hashes for adding new entries", logger):
|
||||||
|
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
||||||
|
hashes_for_file = hashes_by_file[file]
|
||||||
|
hashes_to_process = set()
|
||||||
|
existing_entries = Embeddings.objects.filter(
|
||||||
|
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||||
|
)
|
||||||
|
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||||
|
hashes_to_process = hashes_for_file - existing_entry_hashes
|
||||||
|
# for hashed_val in hashes_for_file:
|
||||||
|
# if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
|
||||||
|
# hashes_to_process.add(hashed_val)
|
||||||
|
|
||||||
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
|
embeddings = self.embeddings_model.embed_documents(data_to_embed)
|
||||||
|
|
||||||
|
with timer("Update the database with new vector embeddings", logger):
|
||||||
|
embeddings_to_create = []
|
||||||
|
for hashed_val, embedding in zip(hashes_to_process, embeddings):
|
||||||
|
entry = hash_to_current_entries[hashed_val]
|
||||||
|
embeddings_to_create.append(
|
||||||
|
Embeddings(
|
||||||
|
user=user,
|
||||||
|
embeddings=embedding,
|
||||||
|
raw=entry.raw,
|
||||||
|
compiled=entry.compiled,
|
||||||
|
heading=entry.heading,
|
||||||
|
file_path=entry.file,
|
||||||
|
file_type=file_type,
|
||||||
|
hashed_value=hashed_val,
|
||||||
|
corpus_id=entry.corpus_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
|
||||||
|
num_new_embeddings += len(new_embeddings)
|
||||||
|
|
||||||
|
dates_to_create = []
|
||||||
|
with timer("Create new date associations for new embeddings", logger):
|
||||||
|
for embedding in new_embeddings:
|
||||||
|
dates = self.date_filter.extract_dates(embedding.raw)
|
||||||
|
for date in dates:
|
||||||
|
dates_to_create.append(
|
||||||
|
EmbeddingsDates(
|
||||||
|
date=date,
|
||||||
|
embeddings=embedding,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
|
||||||
|
if len(new_dates) > 0:
|
||||||
|
logger.info(f"Created {len(new_dates)} new date entries")
|
||||||
|
|
||||||
|
with timer("Identify hashes for removed entries", logger):
|
||||||
|
for file in hashes_by_file:
|
||||||
|
existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||||
|
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||||
|
num_deleted_embeddings += len(to_delete_entry_hashes)
|
||||||
|
EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||||
|
|
||||||
|
with timer("Identify hashes for deleting entries", logger):
|
||||||
|
if deletion_filenames is not None:
|
||||||
|
for file_path in deletion_filenames:
|
||||||
|
deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
|
||||||
|
num_deleted_embeddings += deleted_count
|
||||||
|
|
||||||
|
return num_new_embeddings, num_deleted_embeddings
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mark_entries_for_update(
|
def mark_entries_for_update(
|
||||||
current_entries: List[Entry],
|
current_entries: List[Entry],
|
||||||
|
@ -72,11 +174,11 @@ class TextToJsonl(ABC):
|
||||||
):
|
):
|
||||||
# Hash all current and previous entries to identify new entries
|
# Hash all current and previous entries to identify new entries
|
||||||
with timer("Hash previous, current entries", logger):
|
with timer("Hash previous, current entries", logger):
|
||||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
|
||||||
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
|
previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries))
|
||||||
if deletion_filenames is not None:
|
if deletion_filenames is not None:
|
||||||
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
|
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
|
||||||
deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries))
|
deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries))
|
||||||
else:
|
else:
|
||||||
deletion_entry_hashes = []
|
deletion_entry_hashes = []
|
||||||
|
|
||||||
|
|
|
@ -2,14 +2,15 @@
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import yaml
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional, Union, Any
|
from typing import List, Optional, Union, Any
|
||||||
|
import asyncio
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, HTTPException, Header, Request, Depends
|
||||||
from sentence_transformers import util
|
from starlette.authentication import requires
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_processor, configure_server
|
from khoj.configure import configure_processor, configure_server
|
||||||
|
@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.utils.config import TextSearchModel
|
from khoj.utils.config import TextSearchModel
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
|
||||||
FullConfig,
|
FullConfig,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
|
@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
|
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
|
|
||||||
|
from database import adapters
|
||||||
|
from database.adapters import EmbeddingsAdapters
|
||||||
|
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
||||||
|
|
||||||
|
|
||||||
# Initialize Router
|
# Initialize Router
|
||||||
api = APIRouter()
|
api = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def map_config_to_object(content_type: str):
|
||||||
|
if content_type == "org":
|
||||||
|
return LocalOrgConfig
|
||||||
|
if content_type == "markdown":
|
||||||
|
return LocalMarkdownConfig
|
||||||
|
if content_type == "pdf":
|
||||||
|
return LocalPdfConfig
|
||||||
|
if content_type == "plaintext":
|
||||||
|
return LocalPlaintextConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def map_config_to_db(config: FullConfig, user: KhojUser):
|
||||||
|
if config.content_type:
|
||||||
|
if config.content_type.org:
|
||||||
|
await LocalOrgConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalOrgConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.org.input_files,
|
||||||
|
input_filter=config.content_type.org.input_filter,
|
||||||
|
index_heading_entries=config.content_type.org.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.markdown:
|
||||||
|
await LocalMarkdownConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalMarkdownConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.markdown.input_files,
|
||||||
|
input_filter=config.content_type.markdown.input_filter,
|
||||||
|
index_heading_entries=config.content_type.markdown.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.pdf:
|
||||||
|
await LocalPdfConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPdfConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.pdf.input_files,
|
||||||
|
input_filter=config.content_type.pdf.input_filter,
|
||||||
|
index_heading_entries=config.content_type.pdf.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.plaintext:
|
||||||
|
await LocalPlaintextConfig.objects.filter(user=user).adelete()
|
||||||
|
await LocalPlaintextConfig.objects.acreate(
|
||||||
|
input_files=config.content_type.plaintext.input_files,
|
||||||
|
input_filter=config.content_type.plaintext.input_filter,
|
||||||
|
index_heading_entries=config.content_type.plaintext.index_heading_entries,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if config.content_type.github:
|
||||||
|
await adapters.set_user_github_config(
|
||||||
|
user=user,
|
||||||
|
pat_token=config.content_type.github.pat_token,
|
||||||
|
repos=config.content_type.github.repos,
|
||||||
|
)
|
||||||
|
if config.content_type.notion:
|
||||||
|
await adapters.set_notion_config(
|
||||||
|
user=user,
|
||||||
|
token=config.content_type.notion.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# If it's a demo instance, prevent updating any of the configuration.
|
# If it's a demo instance, prevent updating any of the configuration.
|
||||||
if not state.demo:
|
if not state.demo:
|
||||||
|
|
||||||
|
@ -64,7 +127,10 @@ if not state.demo:
|
||||||
state.processor_config = configure_processor(state.config.processor)
|
state.processor_config = configure_processor(state.config.processor)
|
||||||
|
|
||||||
@api.get("/config/data", response_model=FullConfig)
|
@api.get("/config/data", response_model=FullConfig)
|
||||||
def get_config_data():
|
def get_config_data(request: Request):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
|
||||||
|
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.post("/config/data")
|
@api.post("/config/data")
|
||||||
|
@ -73,20 +139,19 @@ if not state.demo:
|
||||||
updated_config: FullConfig,
|
updated_config: FullConfig,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
):
|
):
|
||||||
state.config = updated_config
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
with open(state.config_file, "w") as outfile:
|
await map_config_to_db(updated_config, user)
|
||||||
yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
|
|
||||||
outfile.close()
|
|
||||||
|
|
||||||
configuration_update_metadata = dict()
|
configuration_update_metadata = {}
|
||||||
|
|
||||||
|
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||||
|
|
||||||
if state.config.content_type is not None:
|
if state.config.content_type is not None:
|
||||||
configuration_update_metadata["github"] = state.config.content_type.github is not None
|
configuration_update_metadata["github"] = "github" in enabled_content
|
||||||
configuration_update_metadata["notion"] = state.config.content_type.notion is not None
|
configuration_update_metadata["notion"] = "notion" in enabled_content
|
||||||
configuration_update_metadata["org"] = state.config.content_type.org is not None
|
configuration_update_metadata["org"] = "org" in enabled_content
|
||||||
configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None
|
configuration_update_metadata["pdf"] = "pdf" in enabled_content
|
||||||
configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None
|
configuration_update_metadata["markdown"] = "markdown" in enabled_content
|
||||||
configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None
|
|
||||||
|
|
||||||
if state.config.processor is not None:
|
if state.config.processor is not None:
|
||||||
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
|
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
|
||||||
|
@ -101,6 +166,7 @@ if not state.demo:
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api.post("/config/data/content_type/github", status_code=200)
|
@api.post("/config/data/content_type/github", status_code=200)
|
||||||
|
@requires("authenticated")
|
||||||
async def set_content_config_github_data(
|
async def set_content_config_github_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
updated_config: Union[GithubContentConfig, None],
|
updated_config: Union[GithubContentConfig, None],
|
||||||
|
@ -108,10 +174,13 @@ if not state.demo:
|
||||||
):
|
):
|
||||||
_initialize_config()
|
_initialize_config()
|
||||||
|
|
||||||
if not state.config.content_type:
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
state.config.content_type = ContentConfig(**{"github": updated_config})
|
|
||||||
else:
|
await adapters.set_user_github_config(
|
||||||
state.config.content_type.github = updated_config
|
user=user,
|
||||||
|
pat_token=updated_config.pat_token,
|
||||||
|
repos=updated_config.repos,
|
||||||
|
)
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -121,11 +190,7 @@ if not state.demo:
|
||||||
metadata={"content_type": "github"},
|
metadata={"content_type": "github"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
return {"status": "ok"}
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
@api.post("/config/data/content_type/notion", status_code=200)
|
@api.post("/config/data/content_type/notion", status_code=200)
|
||||||
async def set_content_config_notion_data(
|
async def set_content_config_notion_data(
|
||||||
|
@ -135,10 +200,12 @@ if not state.demo:
|
||||||
):
|
):
|
||||||
_initialize_config()
|
_initialize_config()
|
||||||
|
|
||||||
if not state.config.content_type:
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
state.config.content_type = ContentConfig(**{"notion": updated_config})
|
|
||||||
else:
|
await adapters.set_notion_config(
|
||||||
state.config.content_type.notion = updated_config
|
user=user,
|
||||||
|
token=updated_config.token,
|
||||||
|
)
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -148,11 +215,7 @@ if not state.demo:
|
||||||
metadata={"content_type": "notion"},
|
metadata={"content_type": "notion"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
return {"status": "ok"}
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
|
||||||
async def remove_content_config_data(
|
async def remove_content_config_data(
|
||||||
|
@ -160,8 +223,7 @@ if not state.demo:
|
||||||
content_type: str,
|
content_type: str,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if not state.config or not state.config.content_type:
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -171,31 +233,13 @@ if not state.demo:
|
||||||
metadata={"content_type": content_type},
|
metadata={"content_type": content_type},
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.config.content_type:
|
content_object = map_config_to_object(content_type)
|
||||||
state.config.content_type[content_type] = None
|
await content_object.objects.filter(user=user).adelete()
|
||||||
|
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
|
||||||
|
|
||||||
if content_type == "github":
|
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||||
state.content_index.github = None
|
|
||||||
elif content_type == "notion":
|
|
||||||
state.content_index.notion = None
|
|
||||||
elif content_type == "plugins":
|
|
||||||
state.content_index.plugins = None
|
|
||||||
elif content_type == "pdf":
|
|
||||||
state.content_index.pdf = None
|
|
||||||
elif content_type == "markdown":
|
|
||||||
state.content_index.markdown = None
|
|
||||||
elif content_type == "org":
|
|
||||||
state.content_index.org = None
|
|
||||||
elif content_type == "plaintext":
|
|
||||||
state.content_index.plaintext = None
|
|
||||||
else:
|
|
||||||
logger.warning(f"Request to delete unknown content type: {content_type} via API")
|
|
||||||
|
|
||||||
try:
|
return {"status": "ok"}
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||||
async def remove_processor_conversation_config_data(
|
async def remove_processor_conversation_config_data(
|
||||||
|
@ -228,6 +272,7 @@ if not state.demo:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||||
|
# @requires("authenticated")
|
||||||
async def set_content_config_data(
|
async def set_content_config_data(
|
||||||
request: Request,
|
request: Request,
|
||||||
content_type: str,
|
content_type: str,
|
||||||
|
@ -236,10 +281,10 @@ if not state.demo:
|
||||||
):
|
):
|
||||||
_initialize_config()
|
_initialize_config()
|
||||||
|
|
||||||
if not state.config.content_type:
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
|
||||||
else:
|
content_object = map_config_to_object(content_type)
|
||||||
state.config.content_type[content_type] = updated_config
|
await adapters.set_text_content_config(user, content_object, updated_config)
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -249,11 +294,7 @@ if not state.demo:
|
||||||
metadata={"content_type": content_type},
|
metadata={"content_type": content_type},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
return {"status": "ok"}
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
@api.post("/config/data/processor/conversation/openai", status_code=200)
|
@api.post("/config/data/processor/conversation/openai", status_code=200)
|
||||||
async def set_processor_openai_config_data(
|
async def set_processor_openai_config_data(
|
||||||
|
@ -337,24 +378,23 @@ def get_default_config_data():
|
||||||
|
|
||||||
|
|
||||||
@api.get("/config/types", response_model=List[str])
|
@api.get("/config/types", response_model=List[str])
|
||||||
def get_config_types():
|
def get_config_types(
|
||||||
"""Get configured content types"""
|
request: Request,
|
||||||
if state.config is None or state.config.content_type is None:
|
):
|
||||||
raise HTTPException(
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
status_code=500,
|
|
||||||
detail="Content types not configured. Configure at least one content type on server and restart it.",
|
enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
|
||||||
)
|
|
||||||
|
configured_content_types = list(enabled_file_types)
|
||||||
|
|
||||||
|
if state.config and state.config.content_type:
|
||||||
|
for ctype in state.config.content_type.dict(exclude_none=True):
|
||||||
|
configured_content_types.append(ctype)
|
||||||
|
|
||||||
configured_content_types = state.config.content_type.dict(exclude_none=True)
|
|
||||||
return [
|
return [
|
||||||
search_type.value
|
search_type.value
|
||||||
for search_type in SearchType
|
for search_type in SearchType
|
||||||
if (
|
if (search_type.value in configured_content_types) or search_type == SearchType.All
|
||||||
search_type.value in configured_content_types
|
|
||||||
and getattr(state.content_index, search_type.value) is not None
|
|
||||||
)
|
|
||||||
or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
|
|
||||||
or search_type == SearchType.All
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -372,6 +412,7 @@ async def search(
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Run validation checks
|
# Run validation checks
|
||||||
|
@ -390,10 +431,11 @@ async def search(
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
if user:
|
||||||
if query_cache_key in state.query_cache:
|
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||||
logger.debug(f"Return response from query cache")
|
if query_cache_key in state.query_cache[user.uuid]:
|
||||||
return state.query_cache[query_cache_key]
|
logger.debug(f"Return response from query cache")
|
||||||
|
return state.query_cache[user.uuid][query_cache_key]
|
||||||
|
|
||||||
# Encode query with filter terms removed
|
# Encode query with filter terms removed
|
||||||
defiltered_query = user_query
|
defiltered_query = user_query
|
||||||
|
@ -407,84 +449,31 @@ async def search(
|
||||||
]
|
]
|
||||||
if text_search_models:
|
if text_search_models:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
encoded_asymmetric_query = util.normalize_embeddings(
|
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
|
||||||
text_search_models[0].bi_encoder.encode(
|
|
||||||
[defiltered_query],
|
|
||||||
convert_to_tensor=True,
|
|
||||||
device=state.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
|
if t in [
|
||||||
# query org-mode notes
|
SearchType.All,
|
||||||
search_futures += [
|
SearchType.Org,
|
||||||
executor.submit(
|
SearchType.Markdown,
|
||||||
text_search.query,
|
SearchType.Github,
|
||||||
user_query,
|
SearchType.Notion,
|
||||||
state.search_models.text_search,
|
SearchType.Plaintext,
|
||||||
state.content_index.org,
|
]:
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if (
|
|
||||||
(t == SearchType.Markdown or t == SearchType.All)
|
|
||||||
and state.content_index.markdown
|
|
||||||
and state.search_models.text_search
|
|
||||||
):
|
|
||||||
# query markdown notes
|
# query markdown notes
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
|
user,
|
||||||
user_query,
|
user_query,
|
||||||
state.search_models.text_search,
|
t,
|
||||||
state.content_index.markdown,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
rank_results=r or False,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (
|
elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
||||||
(t == SearchType.Github or t == SearchType.All)
|
|
||||||
and state.content_index.github
|
|
||||||
and state.search_models.text_search
|
|
||||||
):
|
|
||||||
# query github issues
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
state.search_models.text_search,
|
|
||||||
state.content_index.github,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
|
|
||||||
# query pdf files
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
state.search_models.text_search,
|
|
||||||
state.content_index.pdf,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
|
|
||||||
# query images
|
# query images
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
|
@ -497,70 +486,6 @@ async def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (
|
|
||||||
(t == SearchType.All or t in SearchType)
|
|
||||||
and state.content_index.plugins
|
|
||||||
and state.search_models.plugin_search
|
|
||||||
):
|
|
||||||
# query specified plugin type
|
|
||||||
# Get plugin content, search model for specified search type, or the first one if none specified
|
|
||||||
plugin_search = state.search_models.plugin_search.get(t.value) or next(
|
|
||||||
iter(state.search_models.plugin_search.values())
|
|
||||||
)
|
|
||||||
plugin_content = state.content_index.plugins.get(t.value) or next(
|
|
||||||
iter(state.content_index.plugins.values())
|
|
||||||
)
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
plugin_search,
|
|
||||||
plugin_content,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if (
|
|
||||||
(t == SearchType.Notion or t == SearchType.All)
|
|
||||||
and state.content_index.notion
|
|
||||||
and state.search_models.text_search
|
|
||||||
):
|
|
||||||
# query notion pages
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
state.search_models.text_search,
|
|
||||||
state.content_index.notion,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if (
|
|
||||||
(t == SearchType.Plaintext or t == SearchType.All)
|
|
||||||
and state.content_index.plaintext
|
|
||||||
and state.search_models.text_search
|
|
||||||
):
|
|
||||||
# query plaintext files
|
|
||||||
search_futures += [
|
|
||||||
executor.submit(
|
|
||||||
text_search.query,
|
|
||||||
user_query,
|
|
||||||
state.search_models.text_search,
|
|
||||||
state.content_index.plaintext,
|
|
||||||
question_embedding=encoded_asymmetric_query,
|
|
||||||
rank_results=r or False,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
dedupe=dedupe or True,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Query across each requested content types in parallel
|
# Query across each requested content types in parallel
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
for search_future in concurrent.futures.as_completed(search_futures):
|
for search_future in concurrent.futures.as_completed(search_futures):
|
||||||
|
@ -576,15 +501,19 @@ async def search(
|
||||||
count=results_count,
|
count=results_count,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hits, entries = await search_future.result()
|
hits = await search_future.result()
|
||||||
# Collate results
|
# Collate results
|
||||||
results += text_search.collate_results(hits, entries, results_count)
|
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||||
|
|
||||||
# Sort results across all content types and take top results
|
if r:
|
||||||
results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]
|
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
||||||
|
else:
|
||||||
|
# Sort results across all content types and take top results
|
||||||
|
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
state.query_cache[query_cache_key] = results
|
if user:
|
||||||
|
state.query_cache[user.uuid][query_cache_key] = results
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -596,8 +525,6 @@ async def search(
|
||||||
host=host,
|
host=host,
|
||||||
)
|
)
|
||||||
|
|
||||||
state.previous_query = user_query
|
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
|
||||||
|
|
||||||
|
@ -614,12 +541,13 @@ def update(
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
if not state.config:
|
if not state.config:
|
||||||
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
raise HTTPException(status_code=500, detail=error_msg)
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
try:
|
try:
|
||||||
configure_server(state.config, regenerate=force, search_type=t)
|
configure_server(state.config, regenerate=force, search_type=t, user=user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"🚨 Failed to update server via API: {e}"
|
error_msg = f"🚨 Failed to update server via API: {e}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
|
@ -774,6 +702,7 @@ async def extract_references_and_questions(
|
||||||
n: int,
|
n: int,
|
||||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||||
):
|
):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
meta_log = state.processor_config.conversation.meta_log
|
meta_log = state.processor_config.conversation.meta_log
|
||||||
|
|
||||||
|
@ -781,7 +710,7 @@ async def extract_references_and_questions(
|
||||||
compiled_references: List[Any] = []
|
compiled_references: List[Any] = []
|
||||||
inferred_queries: List[str] = []
|
inferred_queries: List[str] = []
|
||||||
|
|
||||||
if state.content_index is None:
|
if not EmbeddingsAdapters.user_has_embeddings(user=user):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
|
||||||
from khoj.processor.conversation.openai.gpt import converse
|
from khoj.processor.conversation.openai.gpt import converse
|
||||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||||
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
|
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
|
||||||
|
from database.models import KhojUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -40,11 +41,13 @@ def update_telemetry_state(
|
||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||||
user_state = {
|
user_state = {
|
||||||
"client_host": request.client.host if request.client else None,
|
"client_host": request.client.host if request.client else None,
|
||||||
"user_agent": user_agent or "unknown",
|
"user_agent": user_agent or "unknown",
|
||||||
"referer": referer or "unknown",
|
"referer": referer or "unknown",
|
||||||
"host": host or "unknown",
|
"host": host or "unknown",
|
||||||
|
"server_id": str(user.uuid) if user else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Union, Dict
|
from typing import Optional, Union, Dict
|
||||||
|
import asyncio
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
|
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
|
||||||
|
@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import state, constants
|
from khoj.utils import state, constants
|
||||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||||
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
|
||||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||||
from khoj.utils.rawconfig import ContentConfig, TextContentConfig
|
|
||||||
from khoj.search_type import text_search, image_search
|
from khoj.search_type import text_search, image_search
|
||||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
from khoj.utils.constants import default_config
|
|
||||||
from khoj.utils.helpers import LRU, get_file_type
|
from khoj.utils.helpers import LRU, get_file_type
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ContentConfig,
|
ContentConfig,
|
||||||
FullConfig,
|
FullConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
)
|
)
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
|
||||||
from khoj.utils.config import (
|
from khoj.utils.config import (
|
||||||
ContentIndex,
|
ContentIndex,
|
||||||
SearchModels,
|
SearchModels,
|
||||||
)
|
)
|
||||||
|
from database.models import (
|
||||||
|
KhojUser,
|
||||||
|
GithubConfig,
|
||||||
|
NotionConfig,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -68,14 +68,14 @@ async def update(
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
if x_api_key != "secret":
|
if x_api_key != "secret":
|
||||||
raise HTTPException(status_code=401, detail="Invalid API Key")
|
raise HTTPException(status_code=401, detail="Invalid API Key")
|
||||||
state.config_lock.acquire()
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"📬 Updating content index via API call by {client} client")
|
logger.info(f"📬 Updating content index via API call by {client} client")
|
||||||
org_files: Dict[str, str] = {}
|
org_files: Dict[str, str] = {}
|
||||||
markdown_files: Dict[str, str] = {}
|
markdown_files: Dict[str, str] = {}
|
||||||
pdf_files: Dict[str, str] = {}
|
pdf_files: Dict[str, bytes] = {}
|
||||||
plaintext_files: Dict[str, str] = {}
|
plaintext_files: Dict[str, str] = {}
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
|
@ -86,7 +86,7 @@ async def update(
|
||||||
elif file_type == "markdown":
|
elif file_type == "markdown":
|
||||||
dict_to_update = markdown_files
|
dict_to_update = markdown_files
|
||||||
elif file_type == "pdf":
|
elif file_type == "pdf":
|
||||||
dict_to_update = pdf_files
|
dict_to_update = pdf_files # type: ignore
|
||||||
elif file_type == "plaintext":
|
elif file_type == "plaintext":
|
||||||
dict_to_update = plaintext_files
|
dict_to_update = plaintext_files
|
||||||
|
|
||||||
|
@ -120,30 +120,31 @@ async def update(
|
||||||
github=None,
|
github=None,
|
||||||
notion=None,
|
notion=None,
|
||||||
plaintext=None,
|
plaintext=None,
|
||||||
plugins=None,
|
|
||||||
)
|
)
|
||||||
state.config.content_type = default_content_config
|
state.config.content_type = default_content_config
|
||||||
save_config_to_file_updated_state()
|
save_config_to_file_updated_state()
|
||||||
configure_search(state.search_models, state.config.search_type)
|
configure_search(state.search_models, state.config.search_type)
|
||||||
|
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
state.content_index = configure_content(
|
loop = asyncio.get_event_loop()
|
||||||
|
state.content_index = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
configure_content,
|
||||||
state.content_index,
|
state.content_index,
|
||||||
state.config.content_type,
|
state.config.content_type,
|
||||||
indexer_input.dict(),
|
indexer_input.dict(),
|
||||||
state.search_models,
|
state.search_models,
|
||||||
regenerate=force,
|
force,
|
||||||
t=t,
|
t,
|
||||||
full_corpus=False,
|
False,
|
||||||
|
user,
|
||||||
)
|
)
|
||||||
|
logger.info(f"Finished processing batch indexing request")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
state.config_lock.release()
|
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
|
||||||
if search_models is None:
|
if search_models is None:
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
|
|
||||||
# Initialize Search Models
|
|
||||||
if search_config.asymmetric:
|
|
||||||
logger.info("🔍 📜 Setting up text search model")
|
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
|
|
||||||
if search_config.image:
|
if search_config.image:
|
||||||
logger.info("🔍 🌄 Setting up image search model")
|
logger.info("🔍 🌄 Setting up image search model")
|
||||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
@ -187,16 +183,9 @@ def configure_content(
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
t: Optional[Union[state.SearchType, str]] = None,
|
t: Optional[Union[state.SearchType, str]] = None,
|
||||||
full_corpus: bool = True,
|
full_corpus: bool = True,
|
||||||
|
user: KhojUser = None,
|
||||||
) -> Optional[ContentIndex]:
|
) -> Optional[ContentIndex]:
|
||||||
def has_valid_text_config(config: TextContentConfig):
|
content_index = ContentIndex()
|
||||||
return config.input_files or config.input_filter
|
|
||||||
|
|
||||||
# Run Validation Checks
|
|
||||||
if content_config is None:
|
|
||||||
logger.warning("🚨 No Content configuration available.")
|
|
||||||
return None
|
|
||||||
if content_index is None:
|
|
||||||
content_index = ContentIndex()
|
|
||||||
|
|
||||||
if t in [type.value for type in state.SearchType]:
|
if t in [type.value for type in state.SearchType]:
|
||||||
t = state.SearchType(t).value
|
t = state.SearchType(t).value
|
||||||
|
@ -209,59 +198,30 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (
|
if (t == None or t == state.SearchType.Org.value) and files["org"]:
|
||||||
(t == None or t == state.SearchType.Org.value)
|
|
||||||
and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"])
|
|
||||||
and search_models.text_search
|
|
||||||
):
|
|
||||||
if content_config.org == None:
|
|
||||||
logger.info("🦄 No configuration for orgmode notes. Using default configuration.")
|
|
||||||
default_configuration = default_config["content-type"]["org"] # type: ignore
|
|
||||||
content_config.org = TextContentConfig(
|
|
||||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
|
||||||
embeddings_file=default_configuration["embeddings-file"],
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("🦄 Setting up search for orgmode notes")
|
logger.info("🦄 Setting up search for orgmode notes")
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
content_index.org = text_search.setup(
|
text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
files.get("org"),
|
files.get("org"),
|
||||||
content_config.org,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
|
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (
|
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
|
||||||
(t == None or t == state.SearchType.Markdown.value)
|
|
||||||
and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"])
|
|
||||||
and search_models.text_search
|
|
||||||
and files["markdown"]
|
|
||||||
):
|
|
||||||
if content_config.markdown == None:
|
|
||||||
logger.info("💎 No configuration for markdown notes. Using default configuration.")
|
|
||||||
default_configuration = default_config["content-type"]["markdown"] # type: ignore
|
|
||||||
content_config.markdown = TextContentConfig(
|
|
||||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
|
||||||
embeddings_file=default_configuration["embeddings-file"],
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("💎 Setting up search for markdown notes")
|
logger.info("💎 Setting up search for markdown notes")
|
||||||
# Extract Entries, Generate Markdown Embeddings
|
# Extract Entries, Generate Markdown Embeddings
|
||||||
content_index.markdown = text_search.setup(
|
text_search.setup(
|
||||||
MarkdownToJsonl,
|
MarkdownToJsonl,
|
||||||
files.get("markdown"),
|
files.get("markdown"),
|
||||||
content_config.markdown,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -269,30 +229,15 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize PDF Search
|
# Initialize PDF Search
|
||||||
if (
|
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
|
||||||
(t == None or t == state.SearchType.Pdf.value)
|
|
||||||
and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"])
|
|
||||||
and search_models.text_search
|
|
||||||
and files["pdf"]
|
|
||||||
):
|
|
||||||
if content_config.pdf == None:
|
|
||||||
logger.info("🖨️ No configuration for pdf notes. Using default configuration.")
|
|
||||||
default_configuration = default_config["content-type"]["pdf"] # type: ignore
|
|
||||||
content_config.pdf = TextContentConfig(
|
|
||||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
|
||||||
embeddings_file=default_configuration["embeddings-file"],
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("🖨️ Setting up search for pdf")
|
logger.info("🖨️ Setting up search for pdf")
|
||||||
# Extract Entries, Generate PDF Embeddings
|
# Extract Entries, Generate PDF Embeddings
|
||||||
content_index.pdf = text_search.setup(
|
text_search.setup(
|
||||||
PdfToJsonl,
|
PdfToJsonl,
|
||||||
files.get("pdf"),
|
files.get("pdf"),
|
||||||
content_config.pdf,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -300,30 +245,15 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Plaintext Search
|
# Initialize Plaintext Search
|
||||||
if (
|
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
|
||||||
(t == None or t == state.SearchType.Plaintext.value)
|
|
||||||
and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"])
|
|
||||||
and search_models.text_search
|
|
||||||
and files["plaintext"]
|
|
||||||
):
|
|
||||||
if content_config.plaintext == None:
|
|
||||||
logger.info("📄 No configuration for plaintext notes. Using default configuration.")
|
|
||||||
default_configuration = default_config["content-type"]["plaintext"] # type: ignore
|
|
||||||
content_config.plaintext = TextContentConfig(
|
|
||||||
compressed_jsonl=default_configuration["compressed-jsonl"],
|
|
||||||
embeddings_file=default_configuration["embeddings-file"],
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("📄 Setting up search for plaintext")
|
logger.info("📄 Setting up search for plaintext")
|
||||||
# Extract Entries, Generate Plaintext Embeddings
|
# Extract Entries, Generate Plaintext Embeddings
|
||||||
content_index.plaintext = text_search.setup(
|
text_search.setup(
|
||||||
PlaintextToJsonl,
|
PlaintextToJsonl,
|
||||||
files.get("plaintext"),
|
files.get("plaintext"),
|
||||||
content_config.plaintext,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -331,7 +261,12 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
|
if (
|
||||||
|
(t == None or t == state.SearchType.Image.value)
|
||||||
|
and content_config
|
||||||
|
and content_config.image
|
||||||
|
and search_models.image_search
|
||||||
|
):
|
||||||
logger.info("🌄 Setting up search for images")
|
logger.info("🌄 Setting up search for images")
|
||||||
# Extract Entries, Generate Image Embeddings
|
# Extract Entries, Generate Image Embeddings
|
||||||
content_index.image = image_search.setup(
|
content_index.image = image_search.setup(
|
||||||
|
@ -342,17 +277,17 @@ def configure_content(
|
||||||
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
|
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
|
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||||
|
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
|
||||||
logger.info("🐙 Setting up search for github")
|
logger.info("🐙 Setting up search for github")
|
||||||
# Extract Entries, Generate Github Embeddings
|
# Extract Entries, Generate Github Embeddings
|
||||||
content_index.github = text_search.setup(
|
text_search.setup(
|
||||||
GithubToJsonl,
|
GithubToJsonl,
|
||||||
None,
|
None,
|
||||||
content_config.github,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
|
config=github_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -360,42 +295,24 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Notion Search
|
# Initialize Notion Search
|
||||||
if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
|
notion_config = NotionConfig.objects.filter(user=user).first()
|
||||||
|
if (t == None or t in state.SearchType.Notion.value) and notion_config:
|
||||||
logger.info("🔌 Setting up search for notion")
|
logger.info("🔌 Setting up search for notion")
|
||||||
content_index.notion = text_search.setup(
|
text_search.setup(
|
||||||
NotionToJsonl,
|
NotionToJsonl,
|
||||||
None,
|
None,
|
||||||
content_config.notion,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
regenerate=regenerate,
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
full_corpus=full_corpus,
|
||||||
|
user=user,
|
||||||
|
config=notion_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
|
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize External Plugin Search
|
|
||||||
if t == None and content_config.plugins and search_models.text_search:
|
|
||||||
logger.info("🔌 Setting up search for plugins")
|
|
||||||
content_index.plugins = {}
|
|
||||||
for plugin_type, plugin_config in content_config.plugins.items():
|
|
||||||
content_index.plugins[plugin_type] = text_search.setup(
|
|
||||||
JsonlToJsonl,
|
|
||||||
None,
|
|
||||||
plugin_config,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=regenerate,
|
|
||||||
filters=[DateFilter(), WordFilter(), FileFilter()],
|
|
||||||
full_corpus=full_corpus,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# Invalidate Query Cache
|
# Invalidate Query Cache
|
||||||
state.query_cache = LRU()
|
if user:
|
||||||
|
state.query_cache[user.uuid] = LRU()
|
||||||
|
|
||||||
return content_index
|
return content_index
|
||||||
|
|
||||||
|
@ -412,44 +329,9 @@ def load_content(
|
||||||
if content_index is None:
|
if content_index is None:
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
|
|
||||||
if content_config.org:
|
|
||||||
logger.info("🦄 Loading orgmode notes")
|
|
||||||
content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
|
|
||||||
if content_config.markdown:
|
|
||||||
logger.info("💎 Loading markdown notes")
|
|
||||||
content_index.markdown = text_search.load(
|
|
||||||
content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
)
|
|
||||||
if content_config.pdf:
|
|
||||||
logger.info("🖨️ Loading pdf")
|
|
||||||
content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
|
|
||||||
if content_config.plaintext:
|
|
||||||
logger.info("📄 Loading plaintext")
|
|
||||||
content_index.plaintext = text_search.load(
|
|
||||||
content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
)
|
|
||||||
if content_config.image:
|
if content_config.image:
|
||||||
logger.info("🌄 Loading images")
|
logger.info("🌄 Loading images")
|
||||||
content_index.image = image_search.setup(
|
content_index.image = image_search.setup(
|
||||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||||
)
|
)
|
||||||
if content_config.github:
|
|
||||||
logger.info("🐙 Loading github")
|
|
||||||
content_index.github = text_search.load(
|
|
||||||
content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
)
|
|
||||||
if content_config.notion:
|
|
||||||
logger.info("🔌 Loading notion")
|
|
||||||
content_index.notion = text_search.load(
|
|
||||||
content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
)
|
|
||||||
if content_config.plugins:
|
|
||||||
logger.info("🔌 Loading plugins")
|
|
||||||
content_index.plugins = {}
|
|
||||||
for plugin_type, plugin_config in content_config.plugins.items():
|
|
||||||
content_index.plugins[plugin_type] = text_search.load(
|
|
||||||
plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
)
|
|
||||||
|
|
||||||
state.query_cache = LRU()
|
|
||||||
return content_index
|
return content_index
|
||||||
|
|
|
@ -3,10 +3,20 @@ from fastapi import APIRouter
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import HTMLResponse, FileResponse
|
from fastapi.responses import HTMLResponse, FileResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
|
from starlette.authentication import requires
|
||||||
|
from khoj.utils.rawconfig import (
|
||||||
|
TextContentConfig,
|
||||||
|
OpenAIProcessorConfig,
|
||||||
|
FullConfig,
|
||||||
|
GithubContentConfig,
|
||||||
|
GithubRepoConfig,
|
||||||
|
NotionContentConfig,
|
||||||
|
)
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
|
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
||||||
|
from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
@ -29,10 +39,23 @@ def chat_page(request: Request):
|
||||||
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
|
||||||
|
|
||||||
|
|
||||||
|
def map_config_to_object(content_type: str):
|
||||||
|
if content_type == "org":
|
||||||
|
return LocalOrgConfig
|
||||||
|
if content_type == "markdown":
|
||||||
|
return LocalMarkdownConfig
|
||||||
|
if content_type == "pdf":
|
||||||
|
return LocalPdfConfig
|
||||||
|
if content_type == "plaintext":
|
||||||
|
return LocalPlaintextConfig
|
||||||
|
|
||||||
|
|
||||||
if not state.demo:
|
if not state.demo:
|
||||||
|
|
||||||
@web_client.get("/config", response_class=HTMLResponse)
|
@web_client.get("/config", response_class=HTMLResponse)
|
||||||
def config_page(request: Request):
|
def config_page(request: Request):
|
||||||
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||||
default_full_config = FullConfig(
|
default_full_config = FullConfig(
|
||||||
content_type=None,
|
content_type=None,
|
||||||
search_type=None,
|
search_type=None,
|
||||||
|
@ -41,13 +64,13 @@ if not state.demo:
|
||||||
current_config = state.config or json.loads(default_full_config.json())
|
current_config = state.config or json.loads(default_full_config.json())
|
||||||
|
|
||||||
successfully_configured = {
|
successfully_configured = {
|
||||||
"pdf": False,
|
"pdf": ("pdf" in enabled_content),
|
||||||
"markdown": False,
|
"markdown": ("markdown" in enabled_content),
|
||||||
"org": False,
|
"org": ("org" in enabled_content),
|
||||||
"image": False,
|
"image": False,
|
||||||
"github": False,
|
"github": ("github" in enabled_content),
|
||||||
"notion": False,
|
"notion": ("notion" in enabled_content),
|
||||||
"plaintext": False,
|
"plaintext": ("plaintext" in enabled_content),
|
||||||
"enable_offline_model": False,
|
"enable_offline_model": False,
|
||||||
"conversation_openai": False,
|
"conversation_openai": False,
|
||||||
"conversation_gpt4all": False,
|
"conversation_gpt4all": False,
|
||||||
|
@ -56,13 +79,7 @@ if not state.demo:
|
||||||
if state.content_index:
|
if state.content_index:
|
||||||
successfully_configured.update(
|
successfully_configured.update(
|
||||||
{
|
{
|
||||||
"pdf": state.content_index.pdf is not None,
|
|
||||||
"markdown": state.content_index.markdown is not None,
|
|
||||||
"org": state.content_index.org is not None,
|
|
||||||
"image": state.content_index.image is not None,
|
"image": state.content_index.image is not None,
|
||||||
"github": state.content_index.github is not None,
|
|
||||||
"notion": state.content_index.notion is not None,
|
|
||||||
"plaintext": state.content_index.plaintext is not None,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -84,22 +101,29 @@ if not state.demo:
|
||||||
)
|
)
|
||||||
|
|
||||||
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
|
||||||
|
@requires(["authenticated"])
|
||||||
def github_config_page(request: Request):
|
def github_config_page(request: Request):
|
||||||
default_copy = constants.default_config.copy()
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
default_github = default_copy["content-type"]["github"] # type: ignore
|
current_github_config = get_user_github_config(user)
|
||||||
|
|
||||||
default_config = TextContentConfig(
|
if current_github_config:
|
||||||
compressed_jsonl=default_github["compressed-jsonl"],
|
raw_repos = current_github_config.githubrepoconfig.all()
|
||||||
embeddings_file=default_github["embeddings-file"],
|
repos = []
|
||||||
)
|
for repo in raw_repos:
|
||||||
|
repos.append(
|
||||||
current_config = (
|
GithubRepoConfig(
|
||||||
state.config.content_type.github
|
name=repo.name,
|
||||||
if state.config and state.config.content_type and state.config.content_type.github
|
owner=repo.owner,
|
||||||
else default_config
|
branch=repo.branch,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
current_config = json.loads(current_config.json())
|
current_config = GithubContentConfig(
|
||||||
|
pat_token=current_github_config.pat_token,
|
||||||
|
repos=repos,
|
||||||
|
)
|
||||||
|
current_config = json.loads(current_config.json())
|
||||||
|
else:
|
||||||
|
current_config = {} # type: ignore
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"content_type_github_input.html", context={"request": request, "current_config": current_config}
|
"content_type_github_input.html", context={"request": request, "current_config": current_config}
|
||||||
|
@ -107,18 +131,11 @@ if not state.demo:
|
||||||
|
|
||||||
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
|
||||||
def notion_config_page(request: Request):
|
def notion_config_page(request: Request):
|
||||||
default_copy = constants.default_config.copy()
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
default_notion = default_copy["content-type"]["notion"] # type: ignore
|
current_notion_config = get_user_notion_config(user)
|
||||||
|
|
||||||
default_config = TextContentConfig(
|
current_config = NotionContentConfig(
|
||||||
compressed_jsonl=default_notion["compressed-jsonl"],
|
token=current_notion_config.token if current_notion_config else "",
|
||||||
embeddings_file=default_notion["embeddings-file"],
|
|
||||||
)
|
|
||||||
|
|
||||||
current_config = (
|
|
||||||
state.config.content_type.notion
|
|
||||||
if state.config and state.config.content_type and state.config.content_type.notion
|
|
||||||
else default_config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
current_config = json.loads(current_config.json())
|
current_config = json.loads(current_config.json())
|
||||||
|
@ -132,18 +149,16 @@ if not state.demo:
|
||||||
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
if content_type not in VALID_TEXT_CONTENT_TYPES:
|
||||||
return templates.TemplateResponse("config.html", context={"request": request})
|
return templates.TemplateResponse("config.html", context={"request": request})
|
||||||
|
|
||||||
default_copy = constants.default_config.copy()
|
object = map_config_to_object(content_type)
|
||||||
default_content_type = default_copy["content-type"][content_type] # type: ignore
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
config = object.objects.filter(user=user).first()
|
||||||
|
if config == None:
|
||||||
|
config = object.objects.create(user=user)
|
||||||
|
|
||||||
default_config = TextContentConfig(
|
current_config = TextContentConfig(
|
||||||
compressed_jsonl=default_content_type["compressed-jsonl"],
|
input_files=config.input_files,
|
||||||
embeddings_file=default_content_type["embeddings-file"],
|
input_filter=config.input_filter,
|
||||||
)
|
index_heading_entries=config.index_heading_entries,
|
||||||
|
|
||||||
current_config = (
|
|
||||||
state.config.content_type[content_type]
|
|
||||||
if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore
|
|
||||||
else default_config
|
|
||||||
)
|
)
|
||||||
current_config = json.loads(current_config.json())
|
current_config = json.loads(current_config.json())
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,9 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Set, Tuple
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
|
||||||
from khoj.utils.rawconfig import Entry
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFilter(ABC):
|
class BaseFilter(ABC):
|
||||||
@abstractmethod
|
|
||||||
def load(self, entries: List[Entry], *args, **kwargs):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_filter_terms(self, query: str) -> List[str]:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
...
|
...
|
||||||
|
@ -18,10 +11,6 @@ class BaseFilter(ABC):
|
||||||
def can_filter(self, raw_query: str) -> bool:
|
def can_filter(self, raw_query: str) -> bool:
|
||||||
return len(self.get_filter_terms(raw_query)) > 0
|
return len(self.get_filter_terms(raw_query)) > 0
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
...
|
...
|
||||||
|
|
|
@ -25,72 +25,42 @@ class DateFilter(BaseFilter):
|
||||||
# - dt>="last week"
|
# - dt>="last week"
|
||||||
# - dt:"2 years ago"
|
# - dt:"2 years ago"
|
||||||
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
|
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
|
||||||
|
raw_date_regex = r"\d{4}-\d{2}-\d{2}"
|
||||||
|
|
||||||
def __init__(self, entry_key="compiled"):
|
def __init__(self, entry_key="compiled"):
|
||||||
self.entry_key = entry_key
|
self.entry_key = entry_key
|
||||||
self.date_to_entry_ids = defaultdict(set)
|
self.date_to_entry_ids = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
def extract_dates(self, content):
|
||||||
with timer("Created date filter index", logger):
|
pattern_matched_dates = re.findall(self.raw_date_regex, content)
|
||||||
for id, entry in enumerate(entries):
|
|
||||||
# Extract dates from entry
|
# Filter down to valid dates
|
||||||
for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
|
valid_dates = []
|
||||||
# Convert date string in entry to unix timestamp
|
for date_str in pattern_matched_dates:
|
||||||
try:
|
try:
|
||||||
date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
|
valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
except OSError:
|
|
||||||
logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}")
|
return valid_dates
|
||||||
continue
|
|
||||||
self.date_to_entry_ids[date_in_entry].add(id)
|
|
||||||
|
|
||||||
def get_filter_terms(self, query: str) -> List[str]:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Get all filter terms in query"
|
"Get all filter terms in query"
|
||||||
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
|
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
|
||||||
|
|
||||||
|
def get_query_date_range(self, query) -> List:
|
||||||
|
with timer("Extract date range to filter from query", logger):
|
||||||
|
query_daterange = self.extract_date_range(query)
|
||||||
|
|
||||||
|
return query_daterange
|
||||||
|
|
||||||
def defilter(self, query):
|
def defilter(self, query):
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
query = re.sub(rf"\s+{self.date_regex}", " ", query)
|
||||||
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def apply(self, query, entries):
|
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
|
||||||
# extract date range specified in date filter of query
|
|
||||||
with timer("Extract date range to filter from query", logger):
|
|
||||||
query_daterange = self.extract_date_range(query)
|
|
||||||
|
|
||||||
# if no date in query, return all entries
|
|
||||||
if query_daterange == []:
|
|
||||||
return query, set(range(len(entries)))
|
|
||||||
|
|
||||||
query = self.defilter(query)
|
|
||||||
|
|
||||||
# return results from cache if exists
|
|
||||||
cache_key = tuple(query_daterange)
|
|
||||||
if cache_key in self.cache:
|
|
||||||
logger.debug(f"Return date filter results from cache")
|
|
||||||
entries_to_include = self.cache[cache_key]
|
|
||||||
return query, entries_to_include
|
|
||||||
|
|
||||||
if not self.date_to_entry_ids:
|
|
||||||
self.load(entries)
|
|
||||||
|
|
||||||
# find entries containing any dates that fall with date range specified in query
|
|
||||||
with timer("Mark entries satisfying filter", logger):
|
|
||||||
entries_to_include = set()
|
|
||||||
for date_in_entry in self.date_to_entry_ids.keys():
|
|
||||||
# Check if date in entry is within date range specified in query
|
|
||||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
|
||||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
|
||||||
|
|
||||||
# cache results
|
|
||||||
self.cache[cache_key] = entries_to_include
|
|
||||||
|
|
||||||
return query, entries_to_include
|
|
||||||
|
|
||||||
def extract_date_range(self, query):
|
def extract_date_range(self, query):
|
||||||
# find date range filter in query
|
# find date range filter in query
|
||||||
date_range_matches = re.findall(self.date_regex, query)
|
date_range_matches = re.findall(self.date_regex, query)
|
||||||
|
@ -138,6 +108,15 @@ class DateFilter(BaseFilter):
|
||||||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
|
# If the first element is 0, replace it with None
|
||||||
|
|
||||||
|
if effective_date_range[0] == 0:
|
||||||
|
effective_date_range[0] = None
|
||||||
|
|
||||||
|
# If the second element is inf, replace it with None
|
||||||
|
if effective_date_range[1] == inf:
|
||||||
|
effective_date_range[1] = None
|
||||||
|
|
||||||
return effective_date_range
|
return effective_date_range
|
||||||
|
|
||||||
def parse(self, date_str, relative_base=None):
|
def parse(self, date_str, relative_base=None):
|
||||||
|
|
|
@ -21,62 +21,13 @@ class FileFilter(BaseFilter):
|
||||||
self.file_to_entry_map = defaultdict(set)
|
self.file_to_entry_map = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
|
||||||
with timer("Created file filter index", logger):
|
|
||||||
for id, entry in enumerate(entries):
|
|
||||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
|
||||||
|
|
||||||
def get_filter_terms(self, query: str) -> List[str]:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Get all filter terms in query"
|
"Get all filter terms in query"
|
||||||
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
|
return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)]
|
||||||
|
|
||||||
|
def convert_to_regex(self, file_filter: str) -> str:
|
||||||
|
"Convert file filter to regex"
|
||||||
|
return file_filter.replace(".", r"\.").replace("*", r".*")
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.file_filter_regex, "", query).strip()
|
return re.sub(self.file_filter_regex, "", query).strip()
|
||||||
|
|
||||||
def apply(self, query, entries):
|
|
||||||
# Extract file filters from raw query
|
|
||||||
with timer("Extract files_to_search from query", logger):
|
|
||||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
|
||||||
if not raw_files_to_search:
|
|
||||||
return query, set(range(len(entries)))
|
|
||||||
|
|
||||||
# Convert simple file filters with no path separator into regex
|
|
||||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
|
||||||
files_to_search = []
|
|
||||||
for file in sorted(raw_files_to_search):
|
|
||||||
if "/" not in file and "\\" not in file and "*" not in file:
|
|
||||||
files_to_search += [f"*{file}"]
|
|
||||||
else:
|
|
||||||
files_to_search += [file]
|
|
||||||
|
|
||||||
# Remove filter terms from original query
|
|
||||||
query = self.defilter(query)
|
|
||||||
|
|
||||||
# Return item from cache if exists
|
|
||||||
cache_key = tuple(files_to_search)
|
|
||||||
if cache_key in self.cache:
|
|
||||||
logger.debug(f"Return file filter results from cache")
|
|
||||||
included_entry_indices = self.cache[cache_key]
|
|
||||||
return query, included_entry_indices
|
|
||||||
|
|
||||||
if not self.file_to_entry_map:
|
|
||||||
self.load(entries, regenerate=False)
|
|
||||||
|
|
||||||
# Mark entries that contain any blocked_words for exclusion
|
|
||||||
with timer("Mark entries satisfying filter", logger):
|
|
||||||
included_entry_indices = set.union(
|
|
||||||
*[
|
|
||||||
self.file_to_entry_map[entry_file]
|
|
||||||
for entry_file in self.file_to_entry_map.keys()
|
|
||||||
for search_file in files_to_search
|
|
||||||
if fnmatch.fnmatch(entry_file, search_file)
|
|
||||||
],
|
|
||||||
set(),
|
|
||||||
)
|
|
||||||
if not included_entry_indices:
|
|
||||||
return query, {}
|
|
||||||
|
|
||||||
# Cache results
|
|
||||||
self.cache[cache_key] = included_entry_indices
|
|
||||||
|
|
||||||
return query, included_entry_indices
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
from khoj.utils.helpers import LRU, timer
|
from khoj.utils.helpers import LRU
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -22,21 +22,6 @@ class WordFilter(BaseFilter):
|
||||||
self.word_to_entry_index = defaultdict(set)
|
self.word_to_entry_index = defaultdict(set)
|
||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
def load(self, entries, *args, **kwargs):
|
|
||||||
with timer("Created word filter index", logger):
|
|
||||||
self.cache = {} # Clear cache on filter (re-)load
|
|
||||||
entry_splitter = (
|
|
||||||
r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
|
|
||||||
)
|
|
||||||
# Create map of words to entries they exist in
|
|
||||||
for entry_index, entry in enumerate(entries):
|
|
||||||
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
|
|
||||||
if word == "":
|
|
||||||
continue
|
|
||||||
self.word_to_entry_index[word].add(entry_index)
|
|
||||||
|
|
||||||
return self.word_to_entry_index
|
|
||||||
|
|
||||||
def get_filter_terms(self, query: str) -> List[str]:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Get all filter terms in query"
|
"Get all filter terms in query"
|
||||||
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
|
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
|
||||||
|
@ -45,47 +30,3 @@ class WordFilter(BaseFilter):
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||||
|
|
||||||
def apply(self, query, entries):
|
|
||||||
"Find entries containing required and not blocked words specified in query"
|
|
||||||
# Separate natural query from required, blocked words filters
|
|
||||||
with timer("Extract required, blocked filters from query", logger):
|
|
||||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
|
||||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
|
||||||
query = self.defilter(query)
|
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
|
||||||
return query, set(range(len(entries)))
|
|
||||||
|
|
||||||
# Return item from cache if exists
|
|
||||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
|
||||||
if cache_key in self.cache:
|
|
||||||
logger.debug(f"Return word filter results from cache")
|
|
||||||
included_entry_indices = self.cache[cache_key]
|
|
||||||
return query, included_entry_indices
|
|
||||||
|
|
||||||
if not self.word_to_entry_index:
|
|
||||||
self.load(entries, regenerate=False)
|
|
||||||
|
|
||||||
# mark entries that contain all required_words for inclusion
|
|
||||||
with timer("Mark entries satisfying filter", logger):
|
|
||||||
entries_with_all_required_words = set(range(len(entries)))
|
|
||||||
if len(required_words) > 0:
|
|
||||||
entries_with_all_required_words = set.intersection(
|
|
||||||
*[self.word_to_entry_index.get(word, set()) for word in required_words]
|
|
||||||
)
|
|
||||||
|
|
||||||
# mark entries that contain any blocked_words for exclusion
|
|
||||||
entries_with_any_blocked_words = set()
|
|
||||||
if len(blocked_words) > 0:
|
|
||||||
entries_with_any_blocked_words = set.union(
|
|
||||||
*[self.word_to_entry_index.get(word, set()) for word in blocked_words]
|
|
||||||
)
|
|
||||||
|
|
||||||
# get entries satisfying inclusion and exclusion filters
|
|
||||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
|
||||||
|
|
||||||
# Cache results
|
|
||||||
self.cache[cache_key] = included_entry_indices
|
|
||||||
|
|
||||||
return query, included_entry_indices
|
|
||||||
|
|
|
@ -2,25 +2,39 @@
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Type, Union
|
from typing import List, Tuple, Type, Union, Dict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
|
||||||
from khoj.utils.config import TextContent, TextSearchModel
|
from khoj.utils.config import TextSearchModel
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
|
from khoj.utils.state import SearchType
|
||||||
|
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
|
from database.adapters import EmbeddingsAdapters
|
||||||
|
from database.models import KhojUser, Embeddings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
search_type_to_embeddings_type = {
|
||||||
|
SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
|
||||||
|
SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
|
||||||
|
SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
|
||||||
|
SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
|
||||||
|
SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
|
||||||
|
SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
|
||||||
|
SearchType.All.value: None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def initialize_model(search_config: TextSearchConfig):
|
def initialize_model(search_config: TextSearchConfig):
|
||||||
"Initialize model for semantic search on text"
|
"Initialize model for semantic search on text"
|
||||||
|
@ -117,171 +131,102 @@ def load_embeddings(
|
||||||
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
|
user: KhojUser,
|
||||||
raw_query: str,
|
raw_query: str,
|
||||||
search_model: TextSearchModel,
|
type: SearchType = SearchType.All,
|
||||||
content: TextContent,
|
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
rank_results: bool = False,
|
rank_results: bool = False,
|
||||||
score_threshold: float = -math.inf,
|
score_threshold: float = -math.inf,
|
||||||
dedupe: bool = True,
|
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
if (
|
|
||||||
content.entries is None
|
|
||||||
or len(content.entries) == 0
|
|
||||||
or content.corpus_embeddings is None
|
|
||||||
or len(content.corpus_embeddings) == 0
|
|
||||||
):
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
|
file_type = search_type_to_embeddings_type[type.value]
|
||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
query = raw_query
|
||||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
|
|
||||||
|
|
||||||
# If no entries left after filtering, return empty results
|
|
||||||
if entries is None or len(entries) == 0:
|
|
||||||
return [], []
|
|
||||||
# If query only had filters it'll be empty now. So short-circuit and return results.
|
|
||||||
if query.strip() == "":
|
|
||||||
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
|
||||||
return hits, entries
|
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
with timer("Query Encode Time", logger, state.device):
|
with timer("Query Encode Time", logger, state.device):
|
||||||
question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
question_embedding = state.embeddings_model.embed_query(query)
|
||||||
question_embedding = util.normalize_embeddings(question_embedding)
|
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
|
top_k = 10
|
||||||
with timer("Search Time", logger, state.device):
|
with timer("Search Time", logger, state.device):
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
|
hits = EmbeddingsAdapters.search_with_embeddings(
|
||||||
|
user=user,
|
||||||
|
embeddings=question_embedding,
|
||||||
|
max_results=top_k,
|
||||||
|
file_type_filter=file_type,
|
||||||
|
raw_query=raw_query,
|
||||||
|
).all()
|
||||||
|
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
return hits
|
||||||
|
|
||||||
|
|
||||||
|
def collate_results(hits, dedupe=True):
|
||||||
|
hit_ids = set()
|
||||||
|
for hit in hits:
|
||||||
|
if dedupe and hit.corpus_id in hit_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
hit_ids.add(hit.corpus_id)
|
||||||
|
yield SearchResponse.parse_obj(
|
||||||
|
{
|
||||||
|
"entry": hit.raw,
|
||||||
|
"score": hit.distance,
|
||||||
|
"additional": {
|
||||||
|
"file": hit.file_path,
|
||||||
|
"compiled": hit.compiled,
|
||||||
|
"heading": hit.heading,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rerank_and_sort_results(hits, query):
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results and search_model.cross_encoder:
|
hits = cross_encoder_score(query, hits)
|
||||||
hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
|
|
||||||
|
|
||||||
# Filter results by score threshold
|
# Sort results by cross-encoder score followed by bi-encoder score
|
||||||
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
|
hits = sort_results(rank_results=True, hits=hits)
|
||||||
|
|
||||||
# Order results by cross-encoder score followed by bi-encoder score
|
return hits
|
||||||
hits = sort_results(rank_results, hits)
|
|
||||||
|
|
||||||
# Deduplicate entries by raw entry text before showing to users
|
|
||||||
if dedupe:
|
|
||||||
hits = deduplicate_results(entries, hits)
|
|
||||||
|
|
||||||
return hits, entries
|
|
||||||
|
|
||||||
|
|
||||||
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
|
|
||||||
return [
|
|
||||||
SearchResponse.parse_obj(
|
|
||||||
{
|
|
||||||
"entry": entries[hit["corpus_id"]].raw,
|
|
||||||
"score": f"{hit.get('cross-score') or hit.get('score')}",
|
|
||||||
"additional": {
|
|
||||||
"file": entries[hit["corpus_id"]].file,
|
|
||||||
"compiled": entries[hit["corpus_id"]].compiled,
|
|
||||||
"heading": entries[hit["corpus_id"]].heading,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for hit in hits[0:count]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
text_to_jsonl: Type[TextToJsonl],
|
text_to_jsonl: Type[TextEmbeddings],
|
||||||
files: dict[str, str],
|
files: dict[str, str],
|
||||||
config: TextConfigBase,
|
|
||||||
bi_encoder: BaseEncoder,
|
|
||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
filters: List[BaseFilter] = [],
|
|
||||||
normalize: bool = True,
|
|
||||||
full_corpus: bool = True,
|
full_corpus: bool = True,
|
||||||
) -> TextContent:
|
user: KhojUser = None,
|
||||||
# Map notes in text files to (compressed) JSONL formatted file
|
config=None,
|
||||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
) -> None:
|
||||||
previous_entries = []
|
if config:
|
||||||
if config.compressed_jsonl.exists() and not regenerate:
|
num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process(
|
||||||
previous_entries = extract_entries(config.compressed_jsonl)
|
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||||
entries_with_indices = text_to_jsonl(config).process(
|
)
|
||||||
previous_entries=previous_entries, files=files, full_corpus=full_corpus
|
else:
|
||||||
)
|
num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process(
|
||||||
|
files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
|
||||||
# Extract Updated Entries
|
|
||||||
entries = extract_entries(config.compressed_jsonl)
|
|
||||||
if is_none_or_empty(entries):
|
|
||||||
config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
|
|
||||||
raise ValueError(
|
|
||||||
f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
file_names = [file_name for file_name in files]
|
||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
|
||||||
corpus_embeddings = compute_embeddings(
|
logger.info(
|
||||||
entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
|
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for filter in filters:
|
|
||||||
filter.load(entries, regenerate=regenerate)
|
|
||||||
|
|
||||||
return TextContent(entries, corpus_embeddings, filters)
|
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
||||||
|
|
||||||
|
|
||||||
def load(
|
|
||||||
config: TextConfigBase,
|
|
||||||
filters: List[BaseFilter] = [],
|
|
||||||
) -> TextContent:
|
|
||||||
# Map notes in text files to (compressed) JSONL formatted file
|
|
||||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
|
||||||
entries = extract_entries(config.compressed_jsonl)
|
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
|
||||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
|
||||||
corpus_embeddings = load_embeddings(config.embeddings_file)
|
|
||||||
|
|
||||||
for filter in filters:
|
|
||||||
filter.load(entries, regenerate=False)
|
|
||||||
|
|
||||||
return TextContent(entries, corpus_embeddings, filters)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_filters(
|
|
||||||
query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
|
|
||||||
) -> Tuple[str, List[Entry], torch.Tensor]:
|
|
||||||
"""Filter query, entries and embeddings before semantic search"""
|
|
||||||
|
|
||||||
with timer("Total Filter Time", logger, state.device):
|
|
||||||
included_entry_indices = set(range(len(entries)))
|
|
||||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
|
||||||
for filter in filters_in_query:
|
|
||||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
|
||||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
|
||||||
|
|
||||||
# Get entries (and associated embeddings) satisfying all filters
|
|
||||||
if not included_entry_indices:
|
|
||||||
return "", [], torch.tensor([], device=state.device)
|
|
||||||
else:
|
|
||||||
entries = [entries[id] for id in included_entry_indices]
|
|
||||||
corpus_embeddings = torch.index_select(
|
|
||||||
corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
|
|
||||||
)
|
|
||||||
|
|
||||||
return query, entries, corpus_embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
|
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
|
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||||
cross_scores = cross_encoder.predict(cross_inp)
|
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
hits[idx]["cross-score"] = cross_scores[idx]
|
hits[idx]["cross_score"] = cross_scores[idx]
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||||
with timer("Rank Time", logger, state.device):
|
with timer("Rank Time", logger, state.device):
|
||||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
||||||
return hits
|
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
|
|
||||||
"""Deduplicate entries by raw entry text before showing to users
|
|
||||||
Compiled entries are split by max tokens supported by ML models.
|
|
||||||
This can result in duplicate hits, entries shown to user."""
|
|
||||||
|
|
||||||
with timer("Deduplication Time", logger, state.device):
|
|
||||||
seen, original_hits_count = set(), len(hits)
|
|
||||||
hits = [
|
|
||||||
hit
|
|
||||||
for hit in hits
|
|
||||||
if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
|
|
||||||
]
|
|
||||||
duplicate_hits = original_hits_count - len(hits)
|
|
||||||
|
|
||||||
logger.debug(f"Removed {duplicate_hits} duplicates")
|
|
||||||
return hits
|
return hits
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
import os
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import resolve_absolute_path
|
from khoj.utils.helpers import resolve_absolute_path
|
||||||
|
@ -34,6 +35,12 @@ def cli(args=None):
|
||||||
)
|
)
|
||||||
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
||||||
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
|
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
|
||||||
|
parser.add_argument(
|
||||||
|
"--anonymous-mode",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args(args)
|
args = parser.parse_args(args)
|
||||||
|
|
||||||
|
@ -51,6 +58,8 @@ def cli(args=None):
|
||||||
else:
|
else:
|
||||||
args = run_migrations(args)
|
args = run_migrations(args)
|
||||||
args.config = parse_config_from_file(args.config_file)
|
args.config = parse_config_from_file(args.config_file)
|
||||||
|
if os.environ.get("DEBUG"):
|
||||||
|
args.config.app.should_log_telemetry = False
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
|
@ -41,9 +41,7 @@ class ProcessorType(str, Enum):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TextContent:
|
class TextContent:
|
||||||
entries: List[Entry]
|
enabled: bool
|
||||||
corpus_embeddings: torch.Tensor
|
|
||||||
filters: List[BaseFilter]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -67,21 +65,13 @@ class ImageSearchModel:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContentIndex:
|
class ContentIndex:
|
||||||
org: Optional[TextContent] = None
|
|
||||||
markdown: Optional[TextContent] = None
|
|
||||||
pdf: Optional[TextContent] = None
|
|
||||||
github: Optional[TextContent] = None
|
|
||||||
notion: Optional[TextContent] = None
|
|
||||||
image: Optional[ImageContent] = None
|
image: Optional[ImageContent] = None
|
||||||
plaintext: Optional[TextContent] = None
|
|
||||||
plugins: Optional[Dict[str, TextContent]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchModels:
|
class SearchModels:
|
||||||
text_search: Optional[TextSearchModel] = None
|
text_search: Optional[TextSearchModel] = None
|
||||||
image_search: Optional[ImageSearchModel] = None
|
image_search: Optional[ImageSearchModel] = None
|
||||||
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/"
|
||||||
empty_escape_sequences = "\n|\r|\t| "
|
empty_escape_sequences = "\n|\r|\t| "
|
||||||
app_env_filepath = "~/.khoj/env"
|
app_env_filepath = "~/.khoj/env"
|
||||||
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
||||||
|
content_directory = "~/.khoj/content/"
|
||||||
|
|
||||||
empty_config = {
|
empty_config = {
|
||||||
"content-type": {
|
"content-type": {
|
||||||
|
|
|
@ -5,29 +5,39 @@ from typing import Optional
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
|
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
|
||||||
from khoj.utils.rawconfig import TextContentConfig, ContentConfig
|
from khoj.utils.rawconfig import TextContentConfig
|
||||||
from khoj.utils.config import SearchType
|
from khoj.utils.config import SearchType
|
||||||
|
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
|
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
|
||||||
files = {}
|
files = {}
|
||||||
|
|
||||||
if config is None:
|
|
||||||
return files
|
|
||||||
|
|
||||||
if search_type == SearchType.All or search_type == SearchType.Org:
|
if search_type == SearchType.All or search_type == SearchType.Org:
|
||||||
files["org"] = get_org_files(config.org) if config.org else {}
|
org_config = LocalOrgConfig.objects.filter(user=user).first()
|
||||||
|
files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
|
||||||
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
if search_type == SearchType.All or search_type == SearchType.Markdown:
|
||||||
files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
|
markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
|
||||||
|
files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
|
||||||
if search_type == SearchType.All or search_type == SearchType.Plaintext:
|
if search_type == SearchType.All or search_type == SearchType.Plaintext:
|
||||||
files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
|
plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
|
||||||
|
files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
|
||||||
if search_type == SearchType.All or search_type == SearchType.Pdf:
|
if search_type == SearchType.All or search_type == SearchType.Pdf:
|
||||||
files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
|
pdf_config = LocalPdfConfig.objects.filter(user=user).first()
|
||||||
|
files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def construct_config_from_db(db_config) -> TextContentConfig:
|
||||||
|
return TextContentConfig(
|
||||||
|
input_files=db_config.input_files,
|
||||||
|
input_filter=db_config.input_filter,
|
||||||
|
index_heading_entries=db_config.index_heading_entries,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
|
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
|
||||||
def is_plaintextfile(file: str):
|
def is_plaintextfile(file: str):
|
||||||
"Check if file is plaintext file"
|
"Check if file is plaintext file"
|
||||||
|
|
|
@ -209,10 +209,12 @@ def log_telemetry(
|
||||||
if not app_config or not app_config.should_log_telemetry:
|
if not app_config or not app_config.should_log_telemetry:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if properties.get("server_id") is None:
|
||||||
|
properties["server_id"] = get_server_id()
|
||||||
|
|
||||||
# Populate telemetry data to log
|
# Populate telemetry data to log
|
||||||
request_body = {
|
request_body = {
|
||||||
"telemetry_type": telemetry_type,
|
"telemetry_type": telemetry_type,
|
||||||
"server_id": get_server_id(),
|
|
||||||
"server_version": version("khoj-assistant"),
|
"server_version": version("khoj-assistant"),
|
||||||
"os": platform.system(),
|
"os": platform.system(),
|
||||||
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
# System Packages
|
# System Packages
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Optional, Union, Any
|
from typing import List, Dict, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
|
from khoj.utils.helpers import to_snake_case_from_dash
|
||||||
|
|
||||||
|
|
||||||
class ConfigBase(BaseModel):
|
class ConfigBase(BaseModel):
|
||||||
|
@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase):
|
||||||
embeddings_file: Path
|
embeddings_file: Path
|
||||||
|
|
||||||
|
|
||||||
class TextContentConfig(TextConfigBase):
|
class TextContentConfig(ConfigBase):
|
||||||
input_files: Optional[List[Path]]
|
input_files: Optional[List[Path]]
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]]
|
||||||
index_heading_entries: Optional[bool] = False
|
index_heading_entries: Optional[bool] = False
|
||||||
|
@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase):
|
||||||
branch: Optional[str] = "master"
|
branch: Optional[str] = "master"
|
||||||
|
|
||||||
|
|
||||||
class GithubContentConfig(TextConfigBase):
|
class GithubContentConfig(ConfigBase):
|
||||||
pat_token: str
|
pat_token: str
|
||||||
repos: List[GithubRepoConfig]
|
repos: List[GithubRepoConfig]
|
||||||
|
|
||||||
|
|
||||||
class NotionContentConfig(TextConfigBase):
|
class NotionContentConfig(ConfigBase):
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +64,6 @@ class ContentConfig(ConfigBase):
|
||||||
pdf: Optional[TextContentConfig]
|
pdf: Optional[TextContentConfig]
|
||||||
plaintext: Optional[TextContentConfig]
|
plaintext: Optional[TextContentConfig]
|
||||||
github: Optional[GithubContentConfig]
|
github: Optional[GithubContentConfig]
|
||||||
plugins: Optional[Dict[str, TextContentConfig]]
|
|
||||||
notion: Optional[NotionContentConfig]
|
notion: Optional[NotionContentConfig]
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,7 +122,8 @@ class FullConfig(ConfigBase):
|
||||||
|
|
||||||
class SearchResponse(ConfigBase):
|
class SearchResponse(ConfigBase):
|
||||||
entry: str
|
entry: str
|
||||||
score: str
|
score: float
|
||||||
|
cross_score: Optional[float]
|
||||||
additional: Optional[dict]
|
additional: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,14 +132,21 @@ class Entry:
|
||||||
compiled: str
|
compiled: str
|
||||||
heading: Optional[str]
|
heading: Optional[str]
|
||||||
file: Optional[str]
|
file: Optional[str]
|
||||||
|
corpus_id: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None
|
self,
|
||||||
|
raw: str = None,
|
||||||
|
compiled: str = None,
|
||||||
|
heading: Optional[str] = None,
|
||||||
|
file: Optional[str] = None,
|
||||||
|
corpus_id: uuid.UUID = None,
|
||||||
):
|
):
|
||||||
self.raw = raw
|
self.raw = raw
|
||||||
self.compiled = compiled
|
self.compiled = compiled
|
||||||
self.heading = heading
|
self.heading = heading
|
||||||
self.file = file
|
self.file = file
|
||||||
|
self.corpus_id = str(corpus_id)
|
||||||
|
|
||||||
def to_json(self) -> str:
|
def to_json(self) -> str:
|
||||||
return json.dumps(self.__dict__, ensure_ascii=False)
|
return json.dumps(self.__dict__, ensure_ascii=False)
|
||||||
|
@ -153,4 +161,5 @@ class Entry:
|
||||||
compiled=dictionary["compiled"],
|
compiled=dictionary["compiled"],
|
||||||
file=dictionary.get("file", None),
|
file=dictionary.get("file", None),
|
||||||
heading=dictionary.get("heading", None),
|
heading=dictionary.get("heading", None),
|
||||||
|
corpus_id=dictionary.get("corpus_id", None),
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import threading
|
import threading
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
@ -12,10 +13,13 @@ from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||||
from khoj.utils.helpers import LRU
|
from khoj.utils.helpers import LRU
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
|
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
|
embeddings_model = EmbeddingsModel()
|
||||||
|
cross_encoder_model = CrossEncoderModel()
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
processor_config = ProcessorConfigModel()
|
processor_config = ProcessorConfigModel()
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
|
@ -23,14 +27,14 @@ verbose: int = 0
|
||||||
host: str = None
|
host: str = None
|
||||||
port: int = None
|
port: int = None
|
||||||
cli_args: List[str] = None
|
cli_args: List[str] = None
|
||||||
query_cache = LRU()
|
query_cache: Dict[str, LRU] = defaultdict(LRU)
|
||||||
config_lock = threading.Lock()
|
config_lock = threading.Lock()
|
||||||
chat_lock = threading.Lock()
|
chat_lock = threading.Lock()
|
||||||
SearchType = utils_config.SearchType
|
SearchType = utils_config.SearchType
|
||||||
telemetry: List[Dict[str, str]] = []
|
telemetry: List[Dict[str, str]] = []
|
||||||
previous_query: str = None
|
|
||||||
demo: bool = False
|
demo: bool = False
|
||||||
khoj_version: str = None
|
khoj_version: str = None
|
||||||
|
anonymous_mode: bool = False
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# Use CUDA GPU
|
# Use CUDA GPU
|
||||||
|
|
|
@ -1,15 +1,19 @@
|
||||||
# External Packages
|
# External Packages
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import factory
|
||||||
|
import os
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from app.main import app
|
|
||||||
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
|
||||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
|
@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
|
||||||
OpenAIProcessorConfig,
|
OpenAIProcessorConfig,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
TextContentConfig,
|
TextContentConfig,
|
||||||
GithubContentConfig,
|
|
||||||
GithubRepoConfig,
|
|
||||||
ImageContentConfig,
|
ImageContentConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
TextSearchConfig,
|
TextSearchConfig,
|
||||||
|
@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
|
||||||
)
|
)
|
||||||
from khoj.utils import state, fs_syncer
|
from khoj.utils import state, fs_syncer
|
||||||
from khoj.routers.indexer import configure_content
|
from khoj.routers.indexer import configure_content
|
||||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from database.models import (
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
LocalOrgConfig,
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
LocalMarkdownConfig,
|
||||||
|
LocalPlaintextConfig,
|
||||||
|
LocalPdfConfig,
|
||||||
|
GithubConfig,
|
||||||
|
KhojUser,
|
||||||
|
GithubRepoConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def enable_db_access_for_all_tests(db):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserFactory(factory.django.DjangoModelFactory):
|
||||||
|
class Meta:
|
||||||
|
model = KhojUser
|
||||||
|
|
||||||
|
username = factory.Faker("name")
|
||||||
|
email = factory.Faker("email")
|
||||||
|
password = factory.Faker("password")
|
||||||
|
uuid = factory.Faker("uuid4")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
|
||||||
return search_config
|
return search_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.fixture
|
||||||
|
def default_user():
|
||||||
|
return UserFactory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_models(search_config: SearchConfig):
|
def search_models(search_config: SearchConfig):
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|
||||||
return search_models
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture
|
||||||
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
def anyio_backend():
|
||||||
|
return "asyncio"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
|
||||||
content_dir = tmp_path_factory.mktemp("content")
|
content_dir = tmp_path_factory.mktemp("content")
|
||||||
|
|
||||||
# Generate Image Embeddings from Test Images
|
# Generate Image Embeddings from Test Images
|
||||||
|
@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
||||||
|
|
||||||
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||||
|
|
||||||
# Generate Notes Embeddings from Test Notes
|
LocalOrgConfig.objects.create(
|
||||||
content_config.org = TextContentConfig(
|
|
||||||
input_files=None,
|
input_files=None,
|
||||||
input_filter=["tests/data/org/*.org"],
|
input_filter=["tests/data/org/*.org"],
|
||||||
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
index_heading_entries=False,
|
||||||
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
|
||||||
text_search.setup(
|
|
||||||
OrgToJsonl,
|
|
||||||
get_sample_data("org"),
|
|
||||||
content_config.org,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
|
||||||
filters=filters,
|
|
||||||
)
|
|
||||||
|
|
||||||
content_config.plugins = {
|
|
||||||
"plugin1": TextContentConfig(
|
|
||||||
input_files=[content_dir.joinpath("notes.jsonl.gz")],
|
|
||||||
input_filter=None,
|
|
||||||
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
|
|
||||||
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if os.getenv("GITHUB_PAT_TOKEN"):
|
if os.getenv("GITHUB_PAT_TOKEN"):
|
||||||
content_config.github = GithubContentConfig(
|
GithubConfig.objects.create(
|
||||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
|
||||||
repos=[
|
user=default_user,
|
||||||
GithubRepoConfig(
|
|
||||||
owner="khoj-ai",
|
|
||||||
name="lantern",
|
|
||||||
branch="master",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
|
||||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
content_config.plaintext = TextContentConfig(
|
GithubRepoConfig.objects.create(
|
||||||
|
owner="khoj-ai",
|
||||||
|
name="lantern",
|
||||||
|
branch="master",
|
||||||
|
github_config=GithubConfig.objects.get(user=default_user),
|
||||||
|
)
|
||||||
|
|
||||||
|
LocalPlaintextConfig.objects.create(
|
||||||
input_files=None,
|
input_files=None,
|
||||||
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
||||||
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
|
user=default_user,
|
||||||
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
|
|
||||||
)
|
|
||||||
|
|
||||||
content_config.github = GithubContentConfig(
|
|
||||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
|
||||||
repos=[
|
|
||||||
GithubRepoConfig(
|
|
||||||
owner="khoj-ai",
|
|
||||||
name="lantern",
|
|
||||||
branch="master",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
|
||||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
|
||||||
)
|
|
||||||
|
|
||||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
|
||||||
text_search.setup(
|
|
||||||
JsonlToJsonl,
|
|
||||||
None,
|
|
||||||
content_config.plugins["plugin1"],
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
|
||||||
filters=filters,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return content_config
|
return content_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def md_content_config(tmp_path_factory):
|
def md_content_config():
|
||||||
content_dir = tmp_path_factory.mktemp("content")
|
markdown_config = LocalMarkdownConfig.objects.create(
|
||||||
|
|
||||||
# Generate Embeddings for Markdown Content
|
|
||||||
content_config = ContentConfig()
|
|
||||||
content_config.markdown = TextContentConfig(
|
|
||||||
input_files=None,
|
input_files=None,
|
||||||
input_filter=["tests/data/markdown/*.markdown"],
|
input_filter=["tests/data/markdown/*.markdown"],
|
||||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
|
|
||||||
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return content_config
|
return markdown_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||||
# Initialize app state
|
# Initialize app state
|
||||||
state.config.content_type = md_content_config
|
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
|
|
||||||
# Index Markdown Content for Search
|
# Index Markdown Content for Search
|
||||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
all_files = fs_syncer.collect_files()
|
||||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
|
||||||
state.content_index = configure_content(
|
state.content_index = configure_content(
|
||||||
state.content_index, state.config.content_type, all_files, state.search_models
|
state.content_index, state.config.content_type, all_files, state.search_models
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
state.processor_config = configure_processor(processor_config)
|
state.processor_config = configure_processor(processor_config)
|
||||||
|
state.anonymous_mode = True
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
configure_middleware(app)
|
configure_middleware(app)
|
||||||
|
@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
def fastapi_app():
|
||||||
|
app = FastAPI()
|
||||||
|
configure_routes(app)
|
||||||
|
configure_middleware(app)
|
||||||
|
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client(
|
||||||
|
content_config: ContentConfig,
|
||||||
|
search_config: SearchConfig,
|
||||||
|
processor_config: ProcessorConfig,
|
||||||
|
default_user: KhojUser,
|
||||||
|
):
|
||||||
state.config.content_type = content_config
|
state.config.content_type = content_config
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
|
|
||||||
# These lines help us Mock the Search models for these search types
|
# These lines help us Mock the Search models for these search types
|
||||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
state.content_index.org = text_search.setup(
|
text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
get_sample_data("org"),
|
get_sample_data("org"),
|
||||||
content_config.org,
|
|
||||||
state.search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
|
user=default_user,
|
||||||
)
|
)
|
||||||
state.content_index.image = image_search.setup(
|
state.content_index.image = image_search.setup(
|
||||||
content_config.image, state.search_models.image_search, regenerate=False
|
content_config.image, state.search_models.image_search, regenerate=False
|
||||||
)
|
)
|
||||||
state.content_index.plaintext = text_search.setup(
|
text_search.setup(
|
||||||
PlaintextToJsonl,
|
PlaintextToJsonl,
|
||||||
get_sample_data("plaintext"),
|
get_sample_data("plaintext"),
|
||||||
content_config.plaintext,
|
|
||||||
state.search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
state.processor_config = configure_processor(processor_config)
|
state.processor_config = configure_processor(processor_config)
|
||||||
|
state.anonymous_mode = True
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
configure_middleware(app)
|
configure_middleware(app)
|
||||||
|
@ -288,7 +285,6 @@ def client_offline_chat(
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types(state.config)
|
||||||
|
|
||||||
# Index Markdown Content for Search
|
# Index Markdown Content for Search
|
||||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|
||||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||||
|
@ -298,6 +294,7 @@ def client_offline_chat(
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||||
|
state.anonymous_mode = True
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
configure_middleware(app)
|
configure_middleware(app)
|
||||||
|
@ -306,9 +303,11 @@ def client_offline_chat(
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def new_org_file(content_config: ContentConfig):
|
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
|
||||||
# Setup
|
# Setup
|
||||||
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
|
input_filters = org_config.input_filter
|
||||||
|
new_org_file = Path(input_filters[0]).parent / "new_file.org"
|
||||||
new_org_file.touch()
|
new_org_file.touch()
|
||||||
|
|
||||||
yield new_org_file
|
yield new_org_file
|
||||||
|
@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
|
||||||
new_org_config = deepcopy(content_config.org)
|
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
|
||||||
new_org_config.input_files = [f"{new_org_file}"]
|
return LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
new_org_config.input_filter = None
|
|
||||||
return new_org_config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
|
|
11
tests/data/config.yml
vendored
11
tests/data/config.yml
vendored
|
@ -9,17 +9,6 @@ content-type:
|
||||||
input-filter:
|
input-filter:
|
||||||
- '*.org'
|
- '*.org'
|
||||||
- ~/notes/*.org
|
- ~/notes/*.org
|
||||||
plugins:
|
|
||||||
content_plugin_1:
|
|
||||||
compressed-jsonl: content_plugin_1.jsonl.gz
|
|
||||||
embeddings-file: content_plugin_1_embeddings.pt
|
|
||||||
input-files:
|
|
||||||
- content_plugin_1_new.jsonl.gz
|
|
||||||
content_plugin_2:
|
|
||||||
compressed-jsonl: content_plugin_2.jsonl.gz
|
|
||||||
embeddings-file: content_plugin_2_embeddings.pt
|
|
||||||
input-filter:
|
|
||||||
- '*2_new.jsonl.gz'
|
|
||||||
enable-offline-chat: false
|
enable-offline-chat: false
|
||||||
search-type:
|
search-type:
|
||||||
asymmetric:
|
asymmetric:
|
||||||
|
|
|
@ -48,14 +48,3 @@ def test_cli_config_from_file():
|
||||||
Path("~/first_from_config.org"),
|
Path("~/first_from_config.org"),
|
||||||
Path("~/second_from_config.org"),
|
Path("~/second_from_config.org"),
|
||||||
]
|
]
|
||||||
assert len(actual_args.config.content_type.plugins.keys()) == 2
|
|
||||||
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
|
|
||||||
Path("content_plugin_1_new.jsonl.gz")
|
|
||||||
]
|
|
||||||
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
|
|
||||||
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
|
|
||||||
"content_plugin_1.jsonl.gz"
|
|
||||||
)
|
|
||||||
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
|
|
||||||
"content_plugin_2_embeddings.pt"
|
|
||||||
)
|
|
||||||
|
|
|
@ -2,22 +2,21 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
import pytest
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
import pytest
|
from fastapi import FastAPI
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from app.main import app
|
|
||||||
from khoj.configure import configure_routes, configure_search_types
|
from khoj.configure import configure_routes, configure_search_types
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.state import search_models, content_index, config
|
from khoj.utils.state import search_models, content_index, config
|
||||||
from khoj.search_type import text_search, image_search
|
from khoj.search_type import text_search, image_search
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from database.models import KhojUser
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from database.adapters import EmbeddingsAdapters
|
||||||
|
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
|
@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_search_with_valid_content_type(client):
|
def test_search_with_valid_content_type(client):
|
||||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
|
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/search?q=random&t={content_type}")
|
response = client.get(f"/api/search?q=random&t={content_type}")
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -75,7 +74,7 @@ def test_index_update(client):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_regenerate_with_valid_content_type(client):
|
def test_regenerate_with_valid_content_type(client):
|
||||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
|
||||||
# Arrange
|
# Arrange
|
||||||
files = get_sample_files_data()
|
files = get_sample_files_data()
|
||||||
headers = {"x-api-key": "secret"}
|
headers = {"x-api-key": "secret"}
|
||||||
|
@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
||||||
def test_get_configured_types_via_api(client):
|
def test_get_configured_types_via_api(client, sample_org_data):
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/config/types")
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||||
|
|
||||||
|
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert list(enabled_types) == ["org"]
|
||||||
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_get_configured_types_with_only_plugin_content_config(content_config):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
|
||||||
# Arrange
|
# Arrange
|
||||||
config.content_type = ContentConfig()
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||||
config.content_type.plugins = content_config.plugins
|
|
||||||
state.SearchType = configure_search_types(config)
|
|
||||||
|
|
||||||
configure_routes(app)
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/config/types")
|
response = client.get(f"/api/config/types")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == ["all", "plugin1"]
|
assert response.json() == ["all", "org", "markdown", "image"]
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_get_configured_types_with_no_plugin_content_config(content_config):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||||
# Arrange
|
# Arrange
|
||||||
config.content_type = content_config
|
|
||||||
config.content_type.plugins = None
|
|
||||||
state.SearchType = configure_search_types(config)
|
state.SearchType = configure_search_types(config)
|
||||||
|
original_config = state.config.content_type
|
||||||
|
state.config.content_type = None
|
||||||
|
|
||||||
configure_routes(app)
|
configure_routes(fastapi_app)
|
||||||
client = TestClient(app)
|
client = TestClient(fastapi_app)
|
||||||
|
|
||||||
# Act
|
|
||||||
response = client.get(f"/api/config/types")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "plugin1" not in response.json()
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def test_get_configured_types_with_no_content_config():
|
|
||||||
# Arrange
|
|
||||||
config.content_type = ContentConfig()
|
|
||||||
state.SearchType = configure_search_types(config)
|
|
||||||
|
|
||||||
configure_routes(app)
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/config/types")
|
response = client.get(f"/api/config/types")
|
||||||
|
@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == ["all"]
|
assert response.json() == ["all"]
|
||||||
|
|
||||||
|
# Restore
|
||||||
|
state.config.content_type = original_config
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||||
|
@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
||||||
# Arrange
|
# Arrange
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||||
content_index.org = text_search.setup(
|
|
||||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
|
||||||
)
|
|
||||||
user_query = quote("How to git install application?")
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_notes_search_with_only_filters(
|
def test_notes_search_with_only_filters(
|
||||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter(), FileFilter()]
|
text_search.setup(
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
content_index.org = text_search.setup(
|
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
sample_org_data,
|
sample_org_data,
|
||||||
content_config.org,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
filters=filters,
|
|
||||||
)
|
)
|
||||||
user_query = quote('+"Emacs" file:"*.org"')
|
user_query = quote('+"Emacs" file:"*.org"')
|
||||||
|
|
||||||
|
@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search_with_include_filter(
|
@pytest.mark.django_db(transaction=True)
|
||||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
def test_notes_search_with_include_filter(client, sample_org_data):
|
||||||
):
|
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
content_index.org = text_search.setup(
|
|
||||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
|
||||||
)
|
|
||||||
user_query = quote('How to git install application? +"Emacs"')
|
user_query = quote('How to git install application? +"Emacs"')
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_notes_search_with_exclude_filter(
|
@pytest.mark.django_db(transaction=True)
|
||||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
def test_notes_search_with_exclude_filter(client, sample_org_data):
|
||||||
):
|
|
||||||
# Arrange
|
# Arrange
|
||||||
filters = [WordFilter()]
|
text_search.setup(
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
content_index.org = text_search.setup(
|
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
sample_org_data,
|
sample_org_data,
|
||||||
content_config.org,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
filters=filters,
|
|
||||||
)
|
)
|
||||||
user_query = quote('How to git install application? -"clone"')
|
user_query = quote('How to git install application? -"clone"')
|
||||||
|
|
||||||
|
@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
|
||||||
assert "clone" not in search_result
|
assert "clone" not in search_result
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
|
||||||
|
# Arrange
|
||||||
|
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
|
||||||
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
|
||||||
|
assert len(response.json()) == 0
|
||||||
|
|
||||||
|
|
||||||
def get_sample_files_data():
|
def get_sample_files_data():
|
||||||
return {
|
return {
|
||||||
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
|
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
|
||||||
|
|
|
@ -1,53 +1,12 @@
|
||||||
# Standard Packages
|
# Standard Packages
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from math import inf
|
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.utils.rawconfig import Entry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
|
||||||
def test_date_filter():
|
|
||||||
entries = [
|
|
||||||
Entry(compiled="Entry with no date", raw="Entry with no date"),
|
|
||||||
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
|
|
||||||
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
|
|
||||||
]
|
|
||||||
|
|
||||||
q_with_no_date_filter = "head tail"
|
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 1, 2}
|
|
||||||
|
|
||||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == set()
|
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {2}
|
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {1}
|
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {2}
|
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {1, 2}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||||
|
@ -56,8 +15,8 @@ def test_extract_date_range():
|
||||||
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
||||||
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
||||||
]
|
]
|
||||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
|
||||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
||||||
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
||||||
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
||||||
|
|
|
@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
|
||||||
def test_no_file_filter():
|
def test_no_file_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = "head tail"
|
q_with_no_filter = "head tail"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_filter_with_non_existent_file():
|
def test_file_filter_with_non_existent_file():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_file_filter():
|
def test_single_file_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = 'head file:"file 1.org" tail'
|
q_with_no_filter = 'head file:"file 1.org" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 2}
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_filter_with_partial_match():
|
def test_file_filter_with_partial_match():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = 'head file:"1.org" tail'
|
q_with_no_filter = 'head file:"1.org" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 2}
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_filter_with_regex_match():
|
def test_file_filter_with_regex_match():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = 'head file:"*.org" tail'
|
q_with_no_filter = 'head file:"*.org" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_file_filter():
|
def test_multiple_file_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
file_filter = FileFilter()
|
file_filter = FileFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
|
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_file_filter_terms():
|
def test_get_file_filter_terms():
|
||||||
|
@ -108,7 +84,7 @@ def test_get_file_filter_terms():
|
||||||
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
# Internal Packages
|
|
||||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
|
||||||
from khoj.utils.rawconfig import Entry
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_entries_from_single_input_jsonl(tmp_path):
|
|
||||||
"Convert multiple jsonl entries from single file to entries."
|
|
||||||
# Arrange
|
|
||||||
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
|
|
||||||
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
|
|
||||||
"""
|
|
||||||
input_jsonl_file = create_file(tmp_path, input_jsonl)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
|
|
||||||
entries = list(map(Entry.from_dict, input_jsons))
|
|
||||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(entries) == 2
|
|
||||||
assert output_jsonl == input_jsonl
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_entries_from_multiple_input_jsonls(tmp_path):
|
|
||||||
"Convert multiple jsonl entries from single file to entries."
|
|
||||||
# Arrange
|
|
||||||
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
|
|
||||||
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
|
|
||||||
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
|
|
||||||
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
# Process Each Entry from All Notes Files
|
|
||||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
|
|
||||||
entries = list(map(Entry.from_dict, input_jsons))
|
|
||||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(entries) == 2
|
|
||||||
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_jsonl_files(tmp_path):
|
|
||||||
"Ensure JSONL files specified via input-filter, input-files extracted"
|
|
||||||
# Arrange
|
|
||||||
# Include via input-filter globs
|
|
||||||
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
|
|
||||||
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
|
|
||||||
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
|
|
||||||
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
|
|
||||||
# Include via input-file field
|
|
||||||
file1 = create_file(tmp_path, filename="notes.jsonl")
|
|
||||||
# Not included by any filter
|
|
||||||
create_file(tmp_path, filename="not-included-jsonl.jsonl")
|
|
||||||
create_file(tmp_path, filename="not-included-text.txt")
|
|
||||||
|
|
||||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
|
||||||
|
|
||||||
# Setup input-files, input-filters
|
|
||||||
input_files = [tmp_path / "notes.jsonl"]
|
|
||||||
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(extracted_org_files) == 5
|
|
||||||
assert extracted_org_files == expected_files
|
|
||||||
|
|
||||||
|
|
||||||
# Helper Functions
|
|
||||||
def create_file(tmp_path, entry=None, filename="test.jsonl"):
|
|
||||||
jsonl_file = tmp_path / filename
|
|
||||||
jsonl_file.touch()
|
|
||||||
if entry:
|
|
||||||
jsonl_file.write_text(entry)
|
|
||||||
return jsonl_file
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||||
from khoj.utils.helpers import is_none_or_empty
|
from khoj.utils.helpers import is_none_or_empty
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.utils.fs_syncer import get_org_files
|
from khoj.utils.fs_syncer import get_org_files
|
||||||
|
@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||||
|
|
||||||
# Split each entry from specified Org files by max words
|
# Split each entry from specified Org files by max words
|
||||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||||
TextToJsonl.split_entries_by_max_tokens(
|
TextEmbeddings.split_entries_by_max_tokens(
|
||||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Split entry by max words and drop words larger than max word length
|
# Split entry by max words and drop words larger than max word length
|
||||||
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# "Heading" dropped from compiled version because its over the set max word limit
|
# "Heading" dropped from compiled version because its over the set max word limit
|
||||||
|
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||||
from khoj.utils.fs_syncer import get_plaintext_files
|
from khoj.utils.fs_syncer import get_plaintext_files
|
||||||
from khoj.utils.rawconfig import TextContentConfig
|
from khoj.utils.rawconfig import TextContentConfig
|
||||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||||
|
from database.models import LocalPlaintextConfig, KhojUser
|
||||||
|
|
||||||
|
|
||||||
def test_plaintext_file(tmp_path):
|
def test_plaintext_file(tmp_path):
|
||||||
|
@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
|
||||||
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_html_plaintext_file(content_config):
|
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
||||||
"Ensure HTML files are parsed correctly"
|
"Ensure HTML files are parsed correctly"
|
||||||
# Arrange
|
# Arrange
|
||||||
# Setup input-files, input-filters
|
# Setup input-files, input-filters
|
||||||
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
|
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
|
||||||
|
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||||
|
|
|
@ -3,23 +3,30 @@ import logging
|
||||||
import locale
|
import locale
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.state import content_index, search_models
|
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
from khoj.utils.fs_syncer import get_org_files
|
from khoj.utils.fs_syncer import get_org_files, collect_files
|
||||||
|
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||||
|
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
|
@pytest.mark.django_db
|
||||||
|
def test_text_search_setup_with_missing_file_raises_error(
|
||||||
|
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
|
||||||
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
# Ensure file mentioned in org.input-files is missing
|
# Ensure file mentioned in org.input-files is missing
|
||||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||||
|
@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
|
@pytest.mark.django_db
|
||||||
|
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
orgfile = tmp_path / "directory.org" / "file.org"
|
orgfile = tmp_path / "directory.org" / "file.org"
|
||||||
orgfile.parent.mkdir()
|
orgfile.parent.mkdir()
|
||||||
with open(orgfile, "w") as f:
|
with open(orgfile, "w") as f:
|
||||||
f.write("* Heading\n- List item\n")
|
f.write("* Heading\n- List item\n")
|
||||||
org_content_config = TextContentConfig(
|
|
||||||
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
|
LocalOrgConfig.objects.create(
|
||||||
|
input_filter=[f"{tmp_path}/**/*"],
|
||||||
|
input_files=None,
|
||||||
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
org_files = collect_files(user=default_user)["org"]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# should not raise IsADirectoryError and return orgfile
|
# should not raise IsADirectoryError and return orgfile
|
||||||
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
|
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
def test_text_search_setup_with_empty_file_raises_error(
|
def test_text_search_setup_with_empty_file_raises_error(
|
||||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
data = get_org_files(org_config_with_only_new_file)
|
||||||
# Act
|
# Act
|
||||||
# Generate notes embeddings during asymmetric setup
|
# Generate notes embeddings during asymmetric setup
|
||||||
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
with caplog.at_level(logging.INFO):
|
||||||
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
|
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
|
verify_embeddings(0, default_user)
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
@pytest.mark.django_db
|
||||||
|
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(content_config.org)
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
|
data = get_org_files(org_config)
|
||||||
# Act
|
with caplog.at_level(logging.INFO):
|
||||||
# Regenerate notes embeddings during asymmetric setup
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
notes_model = text_search.setup(
|
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(notes_model.entries) == 10
|
assert "Deleting all embeddings for file type org" in caplog.records[1].message
|
||||||
assert len(notes_model.corpus_embeddings) == 10
|
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
|
@pytest.mark.django_db
|
||||||
|
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
caplog.set_level(logging.INFO, logger="khoj")
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
|
data = get_org_files(org_config)
|
||||||
data = get_org_files(content_config.org)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Generate initial notes embeddings during asymmetric setup
|
# Generate initial notes embeddings during asymmetric setup
|
||||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
with caplog.at_level(logging.INFO):
|
||||||
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
initial_logs = caplog.text
|
initial_logs = caplog.text
|
||||||
caplog.clear() # Clear logs
|
caplog.clear() # Clear logs
|
||||||
|
|
||||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
with caplog.at_level(logging.INFO):
|
||||||
|
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||||
final_logs = caplog.text
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Creating index from scratch." in initial_logs
|
assert "Deleting all embeddings for file type org" in initial_logs
|
||||||
assert "Creating index from scratch." not in final_logs
|
assert "Deleting all embeddings for file type org" not in final_logs
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
|
# @pytest.mark.asyncio
|
||||||
|
async def test_text_search(search_config: SearchConfig):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(content_config.org)
|
default_user = await KhojUser.objects.acreate(
|
||||||
|
username="test_user", password="test_password", email="test@example.com"
|
||||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
|
||||||
content_index.org = text_search.setup(
|
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
|
||||||
)
|
)
|
||||||
|
# Arrange
|
||||||
|
org_config = await LocalOrgConfig.objects.acreate(
|
||||||
|
input_files=None,
|
||||||
|
input_filter=["tests/data/org/*.org"],
|
||||||
|
index_heading_entries=False,
|
||||||
|
user=default_user,
|
||||||
|
)
|
||||||
|
data = get_org_files(org_config)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
text_search.setup,
|
||||||
|
OrgToJsonl,
|
||||||
|
data,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
default_user,
|
||||||
|
)
|
||||||
|
|
||||||
query = "How to git install application?"
|
query = "How to git install application?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits, entries = await text_search.query(
|
hits = await text_search.query(default_user, query)
|
||||||
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
|
||||||
)
|
|
||||||
|
|
||||||
results = text_search.collate_results(hits, entries, count=1)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
results = text_search.collate_results(hits)
|
||||||
|
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||||
# search results should contain "git clone" entry
|
# search results should contain "git clone" entry
|
||||||
search_result = results[0].entry
|
search_result = results[0].entry
|
||||||
assert "git clone" in search_result
|
assert "git clone" in search_result
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
@pytest.mark.django_db
|
||||||
|
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
|
@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model after adding new org-mode file
|
# reload embeddings, entries, notes model after adding new org-mode file
|
||||||
initial_notes_model = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify newly added org-mode entry is split by max tokens
|
# verify newly added org-mode entry is split by max tokens
|
||||||
assert len(initial_notes_model.entries) == 2
|
record = caplog.records[1]
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 2
|
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
||||||
|
@pytest.mark.django_db
|
||||||
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||||
data = {
|
data = {
|
||||||
"readme.org": """
|
"readme.org": """
|
||||||
* Khoj
|
* Khoj
|
||||||
/Allow natural language search on user content like notes, images using transformer based models/
|
/Allow natural language search on user content like notes, images using transformer based models/
|
||||||
|
|
||||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||||
|
|
||||||
** Dependencies
|
** Dependencies
|
||||||
- Python3
|
- Python3
|
||||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||||
|
|
||||||
** Install
|
** Install
|
||||||
#+begin_src shell
|
#+begin_src shell
|
||||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||||
conda env create -f environment.yml
|
conda env create -f environment.yml
|
||||||
conda activate khoj
|
conda activate khoj
|
||||||
#+end_src"""
|
#+end_src"""
|
||||||
}
|
}
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
OrgToJsonl,
|
OrgToJsonl,
|
||||||
data,
|
data,
|
||||||
org_config_with_only_new_file,
|
|
||||||
search_models.text_search.bi_encoder,
|
|
||||||
regenerate=False,
|
regenerate=False,
|
||||||
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
|
@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# reload embeddings, entries, notes model after adding new org-mode file
|
# reload embeddings, entries, notes model after adding new org-mode file
|
||||||
initial_notes_model = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl,
|
text_search.setup(
|
||||||
data,
|
OrgToJsonl,
|
||||||
org_config_with_only_new_file,
|
data,
|
||||||
search_models.text_search.bi_encoder,
|
regenerate=False,
|
||||||
regenerate=False,
|
full_corpus=False,
|
||||||
full_corpus=False,
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
record = caplog.records[1]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify newly added org-mode entry is split by max tokens
|
# verify newly added org-mode entry is split by max tokens
|
||||||
assert len(initial_notes_model.entries) == 5
|
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 5
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
def test_regenerate_index_with_new_entry(
|
def test_regenerate_index_with_new_entry(
|
||||||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(content_config.org)
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
initial_notes_model = text_search.setup(
|
data = get_org_files(org_config)
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(initial_notes_model.entries) == 10
|
with caplog.at_level(logging.INFO):
|
||||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
|
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
|
|
||||||
# append org-mode entry to first org input file in config
|
# append org-mode entry to first org input file in config
|
||||||
content_config.org.input_files = [f"{new_org_file}"]
|
org_config.input_files = [f"{new_org_file}"]
|
||||||
with open(new_org_file, "w") as f:
|
with open(new_org_file, "w") as f:
|
||||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||||
|
|
||||||
data = get_org_files(content_config.org)
|
data = get_org_files(org_config)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||||
regenerated_notes_model = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(regenerated_notes_model.entries) == 11
|
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
|
||||||
assert len(regenerated_notes_model.corpus_embeddings) == 11
|
verify_embeddings(11, default_user)
|
||||||
|
|
||||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
|
||||||
error_details = compare_index(initial_notes_model, regenerated_notes_model)
|
|
||||||
if error_details:
|
|
||||||
pytest.fail(error_details, False)
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
# reset input_files in config to empty list
|
|
||||||
content_config.org.input_files = []
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db
|
||||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||||
|
@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# load embeddings, entries, notes model after adding new org-mode file
|
# load embeddings, entries, notes model after adding new org-mode file
|
||||||
initial_index = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# update embeddings, entries, notes model after adding new org-mode file
|
# update embeddings, entries, notes model after adding new org-mode file
|
||||||
updated_index = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify only 1 entry added even if there are multiple duplicate entries
|
# verify only 1 entry added even if there are multiple duplicate entries
|
||||||
assert len(initial_index.entries) == len(updated_index.entries) == 1
|
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
|
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||||
|
|
||||||
# verify the same entry is added even when there are multiple duplicate entries
|
verify_embeddings(1, default_user)
|
||||||
error_details = compare_index(initial_index, updated_index)
|
|
||||||
if error_details:
|
|
||||||
pytest.fail(error_details)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
@pytest.mark.django_db
|
||||||
|
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||||
|
|
||||||
|
@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||||
initial_index = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
# update embeddings, entries, notes model after removing an entry from the org file
|
# update embeddings, entries, notes model after removing an entry from the org file
|
||||||
with open(new_file_to_index, "w") as f:
|
with open(new_file_to_index, "w") as f:
|
||||||
|
@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
updated_index = text_search.setup(
|
with caplog.at_level(logging.INFO):
|
||||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify only 1 entry added even if there are multiple duplicate entries
|
# verify only 1 entry added even if there are multiple duplicate entries
|
||||||
assert len(initial_index.entries) == len(updated_index.entries) + 1
|
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
|
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
|
||||||
|
|
||||||
# verify the same entry is added even when there are multiple duplicate entries
|
verify_embeddings(1, default_user)
|
||||||
error_details = compare_index(updated_index, initial_index)
|
|
||||||
if error_details:
|
|
||||||
pytest.fail(error_details)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
@pytest.mark.django_db
|
||||||
|
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(content_config.org)
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
initial_notes_model = text_search.setup(
|
data = get_org_files(org_config)
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
with caplog.at_level(logging.INFO):
|
||||||
)
|
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
# append org-mode entry to first org input file in config
|
# append org-mode entry to first org input file in config
|
||||||
with open(new_org_file, "w") as f:
|
with open(new_org_file, "w") as f:
|
||||||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||||
f.write(new_entry)
|
f.write(new_entry)
|
||||||
|
|
||||||
data = get_org_files(content_config.org)
|
data = get_org_files(org_config)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# update embeddings, entries with the newly added note
|
# update embeddings, entries with the newly added note
|
||||||
content_config.org.input_files = [f"{new_org_file}"]
|
with caplog.at_level(logging.INFO):
|
||||||
final_notes_model = text_search.setup(
|
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
|
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||||
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
|
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||||
|
|
||||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
verify_embeddings(11, default_user)
|
||||||
error_details = compare_index(initial_notes_model, final_notes_model)
|
|
||||||
if error_details:
|
|
||||||
pytest.fail(error_details, False)
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
# reset input_files in config to empty list
|
|
||||||
content_config.org.input_files = []
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||||
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||||
|
# Arrange
|
||||||
|
github_config = GithubConfig.objects.filter(user=default_user).first()
|
||||||
# Act
|
# Act
|
||||||
# Regenerate github embeddings to test asymmetric setup without caching
|
# Regenerate github embeddings to test asymmetric setup without caching
|
||||||
github_model = text_search.setup(
|
text_search.setup(
|
||||||
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
GithubToJsonl,
|
||||||
|
{},
|
||||||
|
regenerate=True,
|
||||||
|
user=default_user,
|
||||||
|
config=github_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(github_model.entries) > 1
|
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
|
||||||
|
assert embeddings > 1
|
||||||
|
|
||||||
|
|
||||||
def compare_index(initial_notes_model, final_notes_model):
|
def verify_embeddings(expected_count, user):
|
||||||
mismatched_entries, mismatched_embeddings = [], []
|
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
|
||||||
for index in range(len(initial_notes_model.entries)):
|
assert embeddings == expected_count
|
||||||
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
|
|
||||||
mismatched_entries.append(index)
|
|
||||||
|
|
||||||
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
|
|
||||||
for index in range(len(initial_notes_model.corpus_embeddings)):
|
|
||||||
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
|
|
||||||
mismatched_embeddings.append(index)
|
|
||||||
|
|
||||||
error_details = ""
|
|
||||||
if mismatched_entries:
|
|
||||||
mismatched_entries_str = ",".join(map(str, mismatched_entries))
|
|
||||||
error_details += f"Entries at {mismatched_entries_str} not equal\n"
|
|
||||||
if mismatched_embeddings:
|
|
||||||
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
|
|
||||||
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
|
|
||||||
|
|
||||||
return error_details
|
|
||||||
|
|
|
@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
|
|
||||||
|
|
||||||
def test_no_word_filter():
|
|
||||||
# Arrange
|
|
||||||
word_filter = WordFilter()
|
|
||||||
entries = arrange_content()
|
|
||||||
q_with_no_filter = "head tail"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
|
||||||
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert can_filter == False
|
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_word_exclude_filter():
|
def test_word_exclude_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
word_filter = WordFilter()
|
word_filter = WordFilter()
|
||||||
entries = arrange_content()
|
|
||||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {0, 2}
|
|
||||||
|
|
||||||
|
|
||||||
def test_word_include_filter():
|
def test_word_include_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
word_filter = WordFilter()
|
word_filter = WordFilter()
|
||||||
entries = arrange_content()
|
|
||||||
query_with_include_filter = 'head +"include_word" tail'
|
query_with_include_filter = 'head +"include_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {2, 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_word_include_and_exclude_filter():
|
def test_word_include_and_exclude_filter():
|
||||||
# Arrange
|
# Arrange
|
||||||
word_filter = WordFilter()
|
word_filter = WordFilter()
|
||||||
entries = arrange_content()
|
|
||||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
assert ret_query == "head tail"
|
|
||||||
assert entry_indices == {2}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_word_filter_terms():
|
def test_get_word_filter_terms():
|
||||||
|
|
Loading…
Reference in a new issue