@@ -155,9 +128,8 @@
inputFilter = null;
}
- var compressed_jsonl = document.getElementById("compressed-jsonl").value;
- var embeddings_file = document.getElementById("embeddings-file").value;
- var index_heading_entries = document.getElementById("index-heading-entries").value;
+ // var index_heading_entries = document.getElementById("index-heading-entries").value;
+ var index_heading_entries = true;
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content_type/{{ content_type }}', {
@@ -169,8 +141,6 @@
body: JSON.stringify({
"input_files": inputFiles,
"input_filter": inputFilter,
- "compressed_jsonl": compressed_jsonl,
- "embeddings_file": embeddings_file,
"index_heading_entries": index_heading_entries
})
})
diff --git a/src/khoj/interface/web/content_type_notion_input.html b/src/khoj/interface/web/content_type_notion_input.html
index dde5f2df..965c1ef5 100644
--- a/src/khoj/interface/web/content_type_notion_input.html
+++ b/src/khoj/interface/web/content_type_notion_input.html
@@ -20,24 +20,6 @@
-
@@ -51,8 +33,6 @@
submit.addEventListener("click", function(event) {
event.preventDefault();
- const compressed_jsonl = document.getElementById("compressed-jsonl").value;
- const embeddings_file = document.getElementById("embeddings-file").value;
const token = document.getElementById("token").value;
if (token == "") {
@@ -70,8 +50,6 @@
},
body: JSON.stringify({
"token": token,
- "compressed_jsonl": compressed_jsonl,
- "embeddings_file": embeddings_file,
})
})
.then(response => response.json())
diff --git a/src/khoj/interface/web/index.html b/src/khoj/interface/web/index.html
index 581ed9b8..ccf1ca71 100644
--- a/src/khoj/interface/web/index.html
+++ b/src/khoj/interface/web/index.html
@@ -172,7 +172,7 @@
url = createRequestUrl(query, type, results_count || 5, rerank);
fetch(url, {
headers: {
- "X-CSRFToken": csrfToken
+ "Content-Type": "application/json"
}
})
.then(response => response.json())
@@ -199,8 +199,8 @@
fetch("/api/config/types")
.then(response => response.json())
.then(enabled_types => {
- // Show warning if no content types are enabled
- if (enabled_types.detail) {
+ // Show warning if no content types are enabled, or just one ("all")
+ if (enabled_types[0] === "all" && enabled_types.length === 1) {
document.getElementById("results").innerHTML = "
To use Khoj search, setup your content plugins on the Khoj
settings page.
";
document.getElementById("query").setAttribute("disabled", "disabled");
document.getElementById("query").setAttribute("placeholder", "Configure Khoj to enable search");
diff --git a/src/app/main.py b/src/khoj/main.py
similarity index 88%
rename from src/app/main.py
rename to src/khoj/main.py
index 16f7cced..a713cc97 100644
--- a/src/app/main.py
+++ b/src/khoj/main.py
@@ -24,10 +24,15 @@ from rich.logging import RichHandler
from django.core.asgi import get_asgi_application
from django.core.management import call_command
-# Internal Packages
-from khoj.configure import configure_routes, initialize_server, configure_middleware
-from khoj.utils import state
-from khoj.utils.cli import cli
+# Initialize Django
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
+django.setup()
+
+# Initialize Django Database
+call_command("migrate", "--noinput")
+
+# Initialize Django Static Files
+call_command("collectstatic", "--noinput")
# Initialize Django
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings")
@@ -54,6 +59,11 @@ app.add_middleware(
# Set Locale
locale.setlocale(locale.LC_ALL, "")
+# Internal Packages. We do this after setting up Django so that Django features are accessible to the app.
+from khoj.configure import configure_routes, initialize_server, configure_middleware
+from khoj.utils import state
+from khoj.utils.cli import cli
+
# Setup Logger
rich_handler = RichHandler(rich_tracebacks=True)
rich_handler.setFormatter(fmt=logging.Formatter(fmt="%(message)s", datefmt="[%X]"))
@@ -95,6 +105,8 @@ def run():
# Mount Django and Static Files
app.mount("/django", django_app, name="django")
+ if not os.path.exists("static"):
+ os.mkdir("static")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Configure Middleware
@@ -111,6 +123,7 @@ def set_state(args):
state.host = args.host
state.port = args.port
state.demo = args.demo
+ state.anonymous_mode = args.anonymous_mode
state.khoj_version = version("khoj-assistant")
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
new file mode 100644
index 00000000..f0e2df77
--- /dev/null
+++ b/src/khoj/processor/embeddings.py
@@ -0,0 +1,57 @@
+from typing import List
+
+import torch
+from langchain.embeddings import HuggingFaceEmbeddings
+from sentence_transformers import CrossEncoder
+
+from khoj.utils.rawconfig import SearchResponse
+
+
+class EmbeddingsModel:
+ def __init__(self):
+ self.model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+ encode_kwargs = {"normalize_embeddings": True}
+ # encode_kwargs = {}
+
+ if torch.cuda.is_available():
+ # Use CUDA GPU
+ device = torch.device("cuda:0")
+ elif torch.backends.mps.is_available():
+ # Use Apple M1 Metal Acceleration
+ device = torch.device("mps")
+ else:
+ device = torch.device("cpu")
+
+ model_kwargs = {"device": device}
+ self.embeddings_model = HuggingFaceEmbeddings(
+ model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
+ )
+
+ def embed_query(self, query):
+ return self.embeddings_model.embed_query(query)
+
+ def embed_documents(self, docs):
+ return self.embeddings_model.embed_documents(docs)
+
+
+class CrossEncoderModel:
+ def __init__(self):
+ self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+
+ if torch.cuda.is_available():
+ # Use CUDA GPU
+ device = torch.device("cuda:0")
+
+ elif torch.backends.mps.is_available():
+ # Use Apple Metal Performance Shaders (MPS) acceleration
+ device = torch.device("mps")
+
+ else:
+ device = torch.device("cpu")
+
+ self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=device)
+
+ def predict(self, query, hits: List[SearchResponse]):
+ cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
+ cross_scores = self.cross_encoder_model.predict(cross__inp)
+ return cross_scores
diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py
index bcd2e530..8feb6a31 100644
--- a/src/khoj/processor/github/github_to_jsonl.py
+++ b/src/khoj/processor/github/github_to_jsonl.py
@@ -2,7 +2,7 @@
import logging
import time
from datetime import datetime
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Tuple
# External Packages
import requests
@@ -12,18 +12,31 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
-from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.utils.jsonl import compress_jsonl_data
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
+from database.models import Embeddings, GithubConfig, KhojUser
logger = logging.getLogger(__name__)
-class GithubToJsonl(TextToJsonl):
- def __init__(self, config: GithubContentConfig):
+class GithubToJsonl(TextEmbeddings):
+ def __init__(self, config: GithubConfig):
super().__init__(config)
- self.config = config
+ raw_repos = config.githubrepoconfig.all()
+ repos = []
+ for repo in raw_repos:
+ repos.append(
+ GithubRepoConfig(
+ name=repo.name,
+ owner=repo.owner,
+ branch=repo.branch,
+ )
+ )
+ self.config = GithubContentConfig(
+ pat_token=config.pat_token,
+ repos=repos,
+ )
self.session = requests.Session()
self.session.headers.update({"Authorization": f"token {self.config.pat_token}"})
@@ -37,7 +50,9 @@ class GithubToJsonl(TextToJsonl):
else:
return
- def process(self, previous_entries=[], files=None, full_corpus=True):
+ def process(
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
if self.config.pat_token is None or self.config.pat_token == "":
logger.error(f"Github PAT token is not set. Skipping github content")
raise ValueError("Github PAT token is not set. Skipping github content")
@@ -45,7 +60,7 @@ class GithubToJsonl(TextToJsonl):
for repo in self.config.repos:
current_entries += self.process_repo(repo)
- return self.update_entries_with_ids(current_entries, previous_entries)
+ return self.update_entries_with_ids(current_entries, user=user)
def process_repo(self, repo: GithubRepoConfig):
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
@@ -80,26 +95,18 @@ class GithubToJsonl(TextToJsonl):
current_entries += issue_entries
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
- current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
+ current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256)
return current_entries
- def update_entries_with_ids(self, current_entries, previous_entries):
+ def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries, Embeddings.EmbeddingsType.GITHUB, key="compiled", logger=logger, user=user
)
- with timer("Write github entries to JSONL file", logger):
- # Process Each Entry from All Notes Files
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
def get_files(self, repo_url: str, repo: GithubRepoConfig):
# Get the contents of the repository
diff --git a/src/khoj/processor/jsonl/__init__.py b/src/khoj/processor/jsonl/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/khoj/processor/jsonl/jsonl_to_jsonl.py b/src/khoj/processor/jsonl/jsonl_to_jsonl.py
deleted file mode 100644
index 4a6fab99..00000000
--- a/src/khoj/processor/jsonl/jsonl_to_jsonl.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Standard Packages
-import glob
-import logging
-from pathlib import Path
-from typing import List
-
-# Internal Packages
-from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.utils.helpers import get_absolute_path, timer
-from khoj.utils.jsonl import load_jsonl, compress_jsonl_data
-from khoj.utils.rawconfig import Entry
-
-
-logger = logging.getLogger(__name__)
-
-
-class JsonlToJsonl(TextToJsonl):
- # Define Functions
- def process(self, previous_entries=[], files: dict[str, str] = {}, full_corpus: bool = True):
- # Extract required fields from config
- input_jsonl_files, input_jsonl_filter, output_file = (
- self.config.input_files,
- self.config.input_filter,
- self.config.compressed_jsonl,
- )
-
- # Get Jsonl Input Files to Process
- all_input_jsonl_files = JsonlToJsonl.get_jsonl_files(input_jsonl_files, input_jsonl_filter)
-
- # Extract Entries from specified jsonl files
- with timer("Parse entries from jsonl files", logger):
- input_jsons = JsonlToJsonl.extract_jsonl_entries(all_input_jsonl_files)
- current_entries = list(map(Entry.from_dict, input_jsons))
-
- # Split entries by max tokens supported by model
- with timer("Split entries by max token size supported by model", logger):
- current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
-
- # Identify, mark and merge any new entries with previous entries
- with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger
- )
-
- with timer("Write entries to JSONL file", logger):
- # Process Each Entry from All Notes Files
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- jsonl_data = JsonlToJsonl.convert_entries_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, output_file)
-
- return entries_with_ids
-
- @staticmethod
- def get_jsonl_files(jsonl_files=None, jsonl_file_filters=None):
- "Get all jsonl files to process"
- absolute_jsonl_files, filtered_jsonl_files = set(), set()
- if jsonl_files:
- absolute_jsonl_files = {get_absolute_path(jsonl_file) for jsonl_file in jsonl_files}
- if jsonl_file_filters:
- filtered_jsonl_files = {
- filtered_file
- for jsonl_file_filter in jsonl_file_filters
- for filtered_file in glob.glob(get_absolute_path(jsonl_file_filter), recursive=True)
- }
-
- all_jsonl_files = sorted(absolute_jsonl_files | filtered_jsonl_files)
-
- files_with_non_jsonl_extensions = {
- jsonl_file for jsonl_file in all_jsonl_files if not jsonl_file.endswith(".jsonl")
- }
- if any(files_with_non_jsonl_extensions):
- print(f"[Warning] There maybe non jsonl files in the input set: {files_with_non_jsonl_extensions}")
-
- logger.debug(f"Processing files: {all_jsonl_files}")
-
- return all_jsonl_files
-
- @staticmethod
- def extract_jsonl_entries(jsonl_files):
- "Extract entries from specified jsonl files"
- entries = []
- for jsonl_file in jsonl_files:
- entries.extend(load_jsonl(Path(jsonl_file)))
- return entries
-
- @staticmethod
- def convert_entries_to_jsonl(entries: List[Entry]):
- "Convert each entry to JSON and collate as JSONL"
- return "".join([f"{entry.to_json()}\n" for entry in entries])
diff --git a/src/khoj/processor/markdown/markdown_to_jsonl.py b/src/khoj/processor/markdown/markdown_to_jsonl.py
index c2f0f0bf..17136b00 100644
--- a/src/khoj/processor/markdown/markdown_to_jsonl.py
+++ b/src/khoj/processor/markdown/markdown_to_jsonl.py
@@ -3,29 +3,28 @@ import logging
import re
import urllib3
from pathlib import Path
-from typing import List
+from typing import Tuple, List
# Internal Packages
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.jsonl import compress_jsonl_data
-from khoj.utils.rawconfig import Entry, TextContentConfig
+from khoj.utils.rawconfig import Entry
+from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
-class MarkdownToJsonl(TextToJsonl):
- def __init__(self, config: TextContentConfig):
- super().__init__(config)
- self.config = config
+class MarkdownToJsonl(TextEmbeddings):
+ def __init__(self):
+ super().__init__()
# Define Functions
- def process(self, previous_entries=[], files=None, full_corpus: bool = True):
+ def process(
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
# Extract required fields from config
- output_file = self.config.compressed_jsonl
-
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@@ -45,19 +44,17 @@ class MarkdownToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries,
+ Embeddings.EmbeddingsType.MARKDOWN,
+ "compiled",
+ logger,
+ deletion_file_names,
+ user,
+ regenerate=regenerate,
)
- with timer("Write markdown entries to JSONL file", logger):
- # Process Each Entry from All Notes Files
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, output_file)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_markdown_entries(markdown_files):
diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py
index 0df56c37..0081350a 100644
--- a/src/khoj/processor/notion/notion_to_jsonl.py
+++ b/src/khoj/processor/notion/notion_to_jsonl.py
@@ -1,5 +1,6 @@
# Standard Packages
import logging
+from typing import Tuple
# External Packages
import requests
@@ -7,9 +8,9 @@ import requests
# Internal Packages
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig
-from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.utils.jsonl import compress_jsonl_data
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.rawconfig import Entry
+from database.models import Embeddings, KhojUser, NotionConfig
from enum import Enum
@@ -49,10 +50,12 @@ class NotionBlockType(Enum):
CALLOUT = "callout"
-class NotionToJsonl(TextToJsonl):
- def __init__(self, config: NotionContentConfig):
+class NotionToJsonl(TextEmbeddings):
+ def __init__(self, config: NotionConfig):
super().__init__(config)
- self.config = config
+ self.config = NotionContentConfig(
+ token=config.token,
+ )
self.session = requests.Session()
self.session.headers.update({"Authorization": f"Bearer {config.token}", "Notion-Version": "2022-02-22"})
self.unsupported_block_types = [
@@ -80,7 +83,9 @@ class NotionToJsonl(TextToJsonl):
self.body_params = {"page_size": 100}
- def process(self, previous_entries=[], files=None, full_corpus=True):
+ def process(
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
current_entries = []
# Get all pages
@@ -112,7 +117,7 @@ class NotionToJsonl(TextToJsonl):
page_entries = self.process_page(p_or_d)
current_entries.extend(page_entries)
- return self.update_entries_with_ids(current_entries, previous_entries)
+ return self.update_entries_with_ids(current_entries, user)
def process_page(self, page):
page_id = page["id"]
@@ -241,19 +246,11 @@ class NotionToJsonl(TextToJsonl):
title = None
return title, content
- def update_entries_with_ids(self, current_entries, previous_entries):
+ def update_entries_with_ids(self, current_entries, user: KhojUser = None):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries, Embeddings.EmbeddingsType.NOTION, key="compiled", logger=logger, user=user
)
- with timer("Write Notion entries to JSONL file", logger):
- # Process Each Entry from all Notion entries
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- jsonl_data = TextToJsonl.convert_text_maps_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, self.config.compressed_jsonl)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
diff --git a/src/khoj/processor/org_mode/org_to_jsonl.py b/src/khoj/processor/org_mode/org_to_jsonl.py
index 2f22add4..90fdc029 100644
--- a/src/khoj/processor/org_mode/org_to_jsonl.py
+++ b/src/khoj/processor/org_mode/org_to_jsonl.py
@@ -5,28 +5,26 @@ from typing import Iterable, List, Tuple
# Internal Packages
from khoj.processor.org_mode import orgnode
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
-from khoj.utils.jsonl import compress_jsonl_data
-from khoj.utils.rawconfig import Entry, TextContentConfig
+from khoj.utils.rawconfig import Entry
from khoj.utils import state
+from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
-class OrgToJsonl(TextToJsonl):
- def __init__(self, config: TextContentConfig):
- super().__init__(config)
- self.config = config
+class OrgToJsonl(TextEmbeddings):
+ def __init__(self):
+ super().__init__()
# Define Functions
def process(
- self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
- ) -> List[Tuple[int, Entry]]:
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
# Extract required fields from config
- output_file = self.config.compressed_jsonl
- index_heading_entries = self.config.index_heading_entries
+ index_heading_entries = True
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
@@ -47,19 +45,17 @@ class OrgToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries,
+ Embeddings.EmbeddingsType.ORG,
+ "compiled",
+ logger,
+ deletion_file_names,
+ user,
+ regenerate=regenerate,
)
- # Process Each Entry from All Notes Files
- with timer("Write org entries to JSONL file", logger):
- entries = map(lambda entry: entry[1], entries_with_ids)
- jsonl_data = self.convert_org_entries_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, output_file)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_org_entries(org_files: dict[str, str]):
diff --git a/src/khoj/processor/pdf/pdf_to_jsonl.py b/src/khoj/processor/pdf/pdf_to_jsonl.py
index c24d9940..3a712c68 100644
--- a/src/khoj/processor/pdf/pdf_to_jsonl.py
+++ b/src/khoj/processor/pdf/pdf_to_jsonl.py
@@ -1,28 +1,31 @@
# Standard Packages
import os
import logging
-from typing import List
+from typing import List, Tuple
import base64
# External Packages
from langchain.document_loaders import PyMuPDFLoader
# Internal Packages
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
-from khoj.utils.jsonl import compress_jsonl_data
from khoj.utils.rawconfig import Entry
+from database.models import Embeddings, KhojUser
logger = logging.getLogger(__name__)
-class PdfToJsonl(TextToJsonl):
- # Define Functions
- def process(self, previous_entries=[], files: dict[str, str] = None, full_corpus: bool = True):
- # Extract required fields from config
- output_file = self.config.compressed_jsonl
+class PdfToJsonl(TextEmbeddings):
+ def __init__(self):
+ super().__init__()
+ # Define Functions
+ def process(
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
+ # Determine which files to delete and which to process
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@@ -40,19 +43,17 @@ class PdfToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries,
+ Embeddings.EmbeddingsType.PDF,
+ "compiled",
+ logger,
+ deletion_file_names,
+ user,
+ regenerate=regenerate,
)
- with timer("Write PDF entries to JSONL file", logger):
- # Process Each Entry from All Notes Files
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- jsonl_data = PdfToJsonl.convert_pdf_maps_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(jsonl_data, output_file)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_pdf_entries(pdf_files):
@@ -62,7 +63,7 @@ class PdfToJsonl(TextToJsonl):
entry_to_location_map = []
for pdf_file in pdf_files:
try:
- # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PyPDFLoader expects a file path
+ # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
tmp_file = f"tmp_pdf_file.pdf"
with open(f"{tmp_file}", "wb") as f:
bytes = pdf_files[pdf_file]
diff --git a/src/khoj/processor/plaintext/plaintext_to_jsonl.py b/src/khoj/processor/plaintext/plaintext_to_jsonl.py
index 3acb656e..965a5a7b 100644
--- a/src/khoj/processor/plaintext/plaintext_to_jsonl.py
+++ b/src/khoj/processor/plaintext/plaintext_to_jsonl.py
@@ -4,22 +4,23 @@ from pathlib import Path
from typing import List, Tuple
# Internal Packages
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import timer
-from khoj.utils.jsonl import compress_jsonl_data
-from khoj.utils.rawconfig import Entry
+from khoj.utils.rawconfig import Entry, TextContentConfig
+from database.models import Embeddings, KhojUser, LocalPlaintextConfig
logger = logging.getLogger(__name__)
-class PlaintextToJsonl(TextToJsonl):
+class PlaintextToJsonl(TextEmbeddings):
+ def __init__(self):
+ super().__init__()
+
# Define Functions
def process(
- self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
- ) -> List[Tuple[int, Entry]]:
- output_file = self.config.compressed_jsonl
-
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])
files_to_process = set(files) - deletion_file_names
@@ -37,19 +38,17 @@ class PlaintextToJsonl(TextToJsonl):
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
- entries_with_ids = TextToJsonl.mark_entries_for_update(
- current_entries, previous_entries, key="compiled", logger=logger, deletion_filenames=deletion_file_names
+ num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
+ current_entries,
+ Embeddings.EmbeddingsType.PLAINTEXT,
+ key="compiled",
+ logger=logger,
+ deletion_filenames=deletion_file_names,
+ user=user,
+ regenerate=regenerate,
)
- with timer("Write entries to JSONL file", logger):
- # Process Each Entry from All Notes Files
- entries = list(map(lambda entry: entry[1], entries_with_ids))
- plaintext_data = PlaintextToJsonl.convert_entries_to_jsonl(entries)
-
- # Compress JSONL formatted Data
- compress_jsonl_data(plaintext_data, output_file)
-
- return entries_with_ids
+ return num_new_embeddings, num_deleted_embeddings
@staticmethod
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
diff --git a/src/khoj/processor/text_to_jsonl.py b/src/khoj/processor/text_to_jsonl.py
index 98f5986f..c83c83b1 100644
--- a/src/khoj/processor/text_to_jsonl.py
+++ b/src/khoj/processor/text_to_jsonl.py
@@ -2,24 +2,33 @@
from abc import ABC, abstractmethod
import hashlib
import logging
-from typing import Callable, List, Tuple, Set
+import uuid
+from tqdm import tqdm
+from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer
+
# Internal Packages
-from khoj.utils.rawconfig import Entry, TextConfigBase
+from khoj.utils.rawconfig import Entry
+from khoj.processor.embeddings import EmbeddingsModel
+from khoj.search_filter.date_filter import DateFilter
+from database.models import KhojUser, Embeddings, EmbeddingsDates
+from database.adapters import EmbeddingsAdapters
logger = logging.getLogger(__name__)
-class TextToJsonl(ABC):
- def __init__(self, config: TextConfigBase):
+class TextEmbeddings(ABC):
+ def __init__(self, config: Any = None):
+ self.embeddings_model = EmbeddingsModel()
self.config = config
+ self.date_filter = DateFilter()
@abstractmethod
def process(
- self, previous_entries: List[Entry] = [], files: dict[str, str] = None, full_corpus: bool = True
- ) -> List[Tuple[int, Entry]]:
+ self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
+ ) -> Tuple[int, int]:
...
@staticmethod
@@ -38,6 +47,7 @@ class TextToJsonl(ABC):
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
+ corpus_id = uuid.uuid4()
# Split entry into chunks of max tokens
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
@@ -57,11 +67,103 @@ class TextToJsonl(ABC):
raw=entry.raw,
heading=entry.heading,
file=entry.file,
+ corpus_id=corpus_id,
)
)
return chunked_entries
+ def update_embeddings(
+ self,
+ current_entries: List[Entry],
+ file_type: str,
+ key="compiled",
+ logger: logging.Logger = None,
+ deletion_filenames: Set[str] = None,
+ user: KhojUser = None,
+ regenerate: bool = False,
+ ):
+ with timer("Construct current entry hashes", logger):
+ hashes_by_file = dict[str, set[str]]()
+ current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
+ hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
+ for entry in tqdm(current_entries, desc="Hashing Entries"):
+ hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry))
+
+ num_deleted_embeddings = 0
+ with timer("Preparing dataset for regeneration", logger):
+ if regenerate:
+ logger.info(f"Deleting all embeddings for file type {file_type}")
+ num_deleted_embeddings = EmbeddingsAdapters.delete_all_embeddings(user, file_type)
+
+ num_new_embeddings = 0
+ with timer("Identify hashes for adding new entries", logger):
+ for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
+ hashes_for_file = hashes_by_file[file]
+ hashes_to_process = set()
+ existing_entries = Embeddings.objects.filter(
+ user=user, hashed_value__in=hashes_for_file, file_type=file_type
+ )
+ existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
+ hashes_to_process = hashes_for_file - existing_entry_hashes
+ # for hashed_val in hashes_for_file:
+ # if not EmbeddingsAdapters.does_embedding_exist(user, hashed_val):
+ # hashes_to_process.add(hashed_val)
+
+ entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
+ data_to_embed = [getattr(entry, key) for entry in entries_to_process]
+ embeddings = self.embeddings_model.embed_documents(data_to_embed)
+
+ with timer("Update the database with new vector embeddings", logger):
+ embeddings_to_create = []
+ for hashed_val, embedding in zip(hashes_to_process, embeddings):
+ entry = hash_to_current_entries[hashed_val]
+ embeddings_to_create.append(
+ Embeddings(
+ user=user,
+ embeddings=embedding,
+ raw=entry.raw,
+ compiled=entry.compiled,
+ heading=entry.heading,
+ file_path=entry.file,
+ file_type=file_type,
+ hashed_value=hashed_val,
+ corpus_id=entry.corpus_id,
+ )
+ )
+ new_embeddings = Embeddings.objects.bulk_create(embeddings_to_create)
+ num_new_embeddings += len(new_embeddings)
+
+ dates_to_create = []
+ with timer("Create new date associations for new embeddings", logger):
+ for embedding in new_embeddings:
+ dates = self.date_filter.extract_dates(embedding.raw)
+ for date in dates:
+ dates_to_create.append(
+ EmbeddingsDates(
+ date=date,
+ embeddings=embedding,
+ )
+ )
+ new_dates = EmbeddingsDates.objects.bulk_create(dates_to_create)
+ if len(new_dates) > 0:
+ logger.info(f"Created {len(new_dates)} new date entries")
+
+ with timer("Identify hashes for removed entries", logger):
+ for file in hashes_by_file:
+ existing_entry_hashes = EmbeddingsAdapters.get_existing_entry_hashes_by_file(user, file)
+ to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
+ num_deleted_embeddings += len(to_delete_entry_hashes)
+ EmbeddingsAdapters.delete_embedding_by_hash(user, hashed_values=list(to_delete_entry_hashes))
+
+ with timer("Identify hashes for deleting entries", logger):
+ if deletion_filenames is not None:
+ for file_path in deletion_filenames:
+ deleted_count = EmbeddingsAdapters.delete_embedding_by_file(user, file_path)
+ num_deleted_embeddings += deleted_count
+
+ return num_new_embeddings, num_deleted_embeddings
+
@staticmethod
def mark_entries_for_update(
current_entries: List[Entry],
@@ -72,11 +174,11 @@ class TextToJsonl(ABC):
):
# Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger):
- current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
- previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
+ current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries))
+ previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries))
if deletion_filenames is not None:
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
- deletion_entry_hashes = list(map(TextToJsonl.hash_func(key), deletion_entries))
+ deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries))
else:
deletion_entry_hashes = []
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 345429e8..7c3e3392 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -2,14 +2,15 @@
import concurrent.futures
import math
import time
-import yaml
import logging
import json
from typing import List, Optional, Union, Any
+import asyncio
# External Packages
-from fastapi import APIRouter, HTTPException, Header, Request
-from sentence_transformers import util
+from fastapi import APIRouter, HTTPException, Header, Request, Depends
+from starlette.authentication import requires
+from asgiref.sync import sync_to_async
# Internal Packages
from khoj.configure import configure_processor, configure_server
@@ -20,7 +21,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import (
- ContentConfig,
FullConfig,
ProcessorConfig,
SearchConfig,
@@ -48,11 +48,74 @@ from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from fastapi.requests import Request
+from database import adapters
+from database.adapters import EmbeddingsAdapters
+from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
+
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
+
+def map_config_to_object(content_type: str):
+    """Return the local text config model for ``content_type``; None when it is not a local text type."""
+    local_config_models = {
+        "org": LocalOrgConfig,
+        "markdown": LocalMarkdownConfig,
+        "pdf": LocalPdfConfig,
+        "plaintext": LocalPlaintextConfig,
+    }
+    return local_config_models.get(content_type)
+
+
+async def map_config_to_db(config: FullConfig, user: KhojUser):
+    """Persist the content-type sections of ``config`` to the database for ``user``.
+
+    Each configured local text source (org, markdown, pdf, plaintext) replaces the
+    user's existing row for that source: the old rows are deleted and a fresh one
+    is created. Github and Notion settings are handed to the database adapters.
+    Sections that are unset (falsy) are left untouched.
+    """
+    content_config = config.content_type
+    if not content_config:
+        return
+
+    # Local text sources share the same fields, so they are handled in one pass.
+    local_text_models = (
+        ("org", LocalOrgConfig),
+        ("markdown", LocalMarkdownConfig),
+        ("pdf", LocalPdfConfig),
+        ("plaintext", LocalPlaintextConfig),
+    )
+    for source_name, config_model in local_text_models:
+        source_config = getattr(content_config, source_name)
+        if not source_config:
+            continue
+        await config_model.objects.filter(user=user).adelete()
+        await config_model.objects.acreate(
+            input_files=source_config.input_files,
+            input_filter=source_config.input_filter,
+            index_heading_entries=source_config.index_heading_entries,
+            user=user,
+        )
+
+    github_config = content_config.github
+    if github_config:
+        await adapters.set_user_github_config(
+            user=user,
+            pat_token=github_config.pat_token,
+            repos=github_config.repos,
+        )
+
+    notion_config = content_config.notion
+    if notion_config:
+        await adapters.set_notion_config(
+            user=user,
+            token=notion_config.token,
+        )
+
+
# If it's a demo instance, prevent updating any of the configuration.
if not state.demo:
@@ -64,7 +127,10 @@ if not state.demo:
state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
- def get_config_data():
+ def get_config_data(request: Request):
+ user = request.user.object if request.user.is_authenticated else None
+ enabled_content = EmbeddingsAdapters.get_unique_file_types(user)
+
return state.config
@api.post("/config/data")
@@ -73,20 +139,19 @@ if not state.demo:
updated_config: FullConfig,
client: Optional[str] = None,
):
- state.config = updated_config
- with open(state.config_file, "w") as outfile:
- yaml.dump(yaml.safe_load(state.config.json(by_alias=True)), outfile)
- outfile.close()
+ user = request.user.object if request.user.is_authenticated else None
+ await map_config_to_db(updated_config, user)
- configuration_update_metadata = dict()
+ configuration_update_metadata = {}
+
+ enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
if state.config.content_type is not None:
- configuration_update_metadata["github"] = state.config.content_type.github is not None
- configuration_update_metadata["notion"] = state.config.content_type.notion is not None
- configuration_update_metadata["org"] = state.config.content_type.org is not None
- configuration_update_metadata["pdf"] = state.config.content_type.pdf is not None
- configuration_update_metadata["markdown"] = state.config.content_type.markdown is not None
- configuration_update_metadata["plugins"] = state.config.content_type.plugins is not None
+ configuration_update_metadata["github"] = "github" in enabled_content
+ configuration_update_metadata["notion"] = "notion" in enabled_content
+ configuration_update_metadata["org"] = "org" in enabled_content
+ configuration_update_metadata["pdf"] = "pdf" in enabled_content
+ configuration_update_metadata["markdown"] = "markdown" in enabled_content
if state.config.processor is not None:
configuration_update_metadata["conversation_processor"] = state.config.processor.conversation is not None
@@ -101,6 +166,7 @@ if not state.demo:
return state.config
@api.post("/config/data/content_type/github", status_code=200)
+ @requires("authenticated")
async def set_content_config_github_data(
request: Request,
updated_config: Union[GithubContentConfig, None],
@@ -108,10 +174,13 @@ if not state.demo:
):
_initialize_config()
- if not state.config.content_type:
- state.config.content_type = ContentConfig(**{"github": updated_config})
- else:
- state.config.content_type.github = updated_config
+ user = request.user.object if request.user.is_authenticated else None
+
+ await adapters.set_user_github_config(
+ user=user,
+ pat_token=updated_config.pat_token,
+ repos=updated_config.repos,
+ )
update_telemetry_state(
request=request,
@@ -121,11 +190,7 @@ if not state.demo:
metadata={"content_type": "github"},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/config/data/content_type/notion", status_code=200)
async def set_content_config_notion_data(
@@ -135,10 +200,12 @@ if not state.demo:
):
_initialize_config()
- if not state.config.content_type:
- state.config.content_type = ContentConfig(**{"notion": updated_config})
- else:
- state.config.content_type.notion = updated_config
+ user = request.user.object if request.user.is_authenticated else None
+
+ await adapters.set_notion_config(
+ user=user,
+ token=updated_config.token,
+ )
update_telemetry_state(
request=request,
@@ -148,11 +215,7 @@ if not state.demo:
metadata={"content_type": "notion"},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/delete/config/data/content_type/{content_type}", status_code=200)
async def remove_content_config_data(
@@ -160,8 +223,7 @@ if not state.demo:
content_type: str,
client: Optional[str] = None,
):
- if not state.config or not state.config.content_type:
- return {"status": "ok"}
+ user = request.user.object if request.user.is_authenticated else None
update_telemetry_state(
request=request,
@@ -171,31 +233,13 @@ if not state.demo:
metadata={"content_type": content_type},
)
- if state.config.content_type:
- state.config.content_type[content_type] = None
+ content_object = map_config_to_object(content_type)
+ await content_object.objects.filter(user=user).adelete()
+ await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
- if content_type == "github":
- state.content_index.github = None
- elif content_type == "notion":
- state.content_index.notion = None
- elif content_type == "plugins":
- state.content_index.plugins = None
- elif content_type == "pdf":
- state.content_index.pdf = None
- elif content_type == "markdown":
- state.content_index.markdown = None
- elif content_type == "org":
- state.content_index.org = None
- elif content_type == "plaintext":
- state.content_index.plaintext = None
- else:
- logger.warning(f"Request to delete unknown content type: {content_type} via API")
+ enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
async def remove_processor_conversation_config_data(
@@ -228,6 +272,7 @@ if not state.demo:
return {"status": "error", "message": str(e)}
@api.post("/config/data/content_type/{content_type}", status_code=200)
+ # @requires("authenticated")
async def set_content_config_data(
request: Request,
content_type: str,
@@ -236,10 +281,10 @@ if not state.demo:
):
_initialize_config()
- if not state.config.content_type:
- state.config.content_type = ContentConfig(**{content_type: updated_config})
- else:
- state.config.content_type[content_type] = updated_config
+ user = request.user.object if request.user.is_authenticated else None
+
+ content_object = map_config_to_object(content_type)
+ await adapters.set_text_content_config(user, content_object, updated_config)
update_telemetry_state(
request=request,
@@ -249,11 +294,7 @@ if not state.demo:
metadata={"content_type": content_type},
)
- try:
- save_config_to_file_updated_state()
- return {"status": "ok"}
- except Exception as e:
- return {"status": "error", "message": str(e)}
+ return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
async def set_processor_openai_config_data(
@@ -337,24 +378,23 @@ def get_default_config_data():
@api.get("/config/types", response_model=List[str])
-def get_config_types():
- """Get configured content types"""
- if state.config is None or state.config.content_type is None:
- raise HTTPException(
- status_code=500,
- detail="Content types not configured. Configure at least one content type on server and restart it.",
- )
+def get_config_types(
+ request: Request,
+):
+ user = request.user.object if request.user.is_authenticated else None
+
+ enabled_file_types = EmbeddingsAdapters.get_unique_file_types(user)
+
+ configured_content_types = list(enabled_file_types)
+
+ if state.config and state.config.content_type:
+ for ctype in state.config.content_type.dict(exclude_none=True):
+ configured_content_types.append(ctype)
- configured_content_types = state.config.content_type.dict(exclude_none=True)
return [
search_type.value
for search_type in SearchType
- if (
- search_type.value in configured_content_types
- and getattr(state.content_index, search_type.value) is not None
- )
- or ("plugins" in configured_content_types and search_type.name in configured_content_types["plugins"])
- or search_type == SearchType.All
+ if (search_type.value in configured_content_types) or search_type == SearchType.All
]
@@ -372,6 +412,7 @@ async def search(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
+ user = request.user.object if request.user.is_authenticated else None
start_time = time.time()
# Run validation checks
@@ -390,10 +431,11 @@ async def search(
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
- query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
- if query_cache_key in state.query_cache:
- logger.debug(f"Return response from query cache")
- return state.query_cache[query_cache_key]
+ if user:
+ query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
+ if query_cache_key in state.query_cache[user.uuid]:
+ logger.debug(f"Return response from query cache")
+ return state.query_cache[user.uuid][query_cache_key]
# Encode query with filter terms removed
defiltered_query = user_query
@@ -407,84 +449,31 @@ async def search(
]
if text_search_models:
with timer("Encoding query took", logger=logger):
- encoded_asymmetric_query = util.normalize_embeddings(
- text_search_models[0].bi_encoder.encode(
- [defiltered_query],
- convert_to_tensor=True,
- device=state.device,
- )
- )
+ encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
- if (t == SearchType.Org or t == SearchType.All) and state.content_index.org and state.search_models.text_search:
- # query org-mode notes
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- state.search_models.text_search,
- state.content_index.org,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
- if (
- (t == SearchType.Markdown or t == SearchType.All)
- and state.content_index.markdown
- and state.search_models.text_search
- ):
+ if t in [
+ SearchType.All,
+ SearchType.Org,
+ SearchType.Markdown,
+ SearchType.Github,
+ SearchType.Notion,
+ SearchType.Plaintext,
+ ]:
# query markdown notes
search_futures += [
executor.submit(
text_search.query,
+ user,
user_query,
- state.search_models.text_search,
- state.content_index.markdown,
+ t,
question_embedding=encoded_asymmetric_query,
rank_results=r or False,
score_threshold=score_threshold,
- dedupe=dedupe or True,
)
]
- if (
- (t == SearchType.Github or t == SearchType.All)
- and state.content_index.github
- and state.search_models.text_search
- ):
- # query github issues
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- state.search_models.text_search,
- state.content_index.github,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
- if (t == SearchType.Pdf or t == SearchType.All) and state.content_index.pdf and state.search_models.text_search:
- # query pdf files
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- state.search_models.text_search,
- state.content_index.pdf,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
- if (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
+ elif (t == SearchType.Image) and state.content_index.image and state.search_models.image_search:
# query images
search_futures += [
executor.submit(
@@ -497,70 +486,6 @@ async def search(
)
]
- if (
- (t == SearchType.All or t in SearchType)
- and state.content_index.plugins
- and state.search_models.plugin_search
- ):
- # query specified plugin type
- # Get plugin content, search model for specified search type, or the first one if none specified
- plugin_search = state.search_models.plugin_search.get(t.value) or next(
- iter(state.search_models.plugin_search.values())
- )
- plugin_content = state.content_index.plugins.get(t.value) or next(
- iter(state.content_index.plugins.values())
- )
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- plugin_search,
- plugin_content,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
- if (
- (t == SearchType.Notion or t == SearchType.All)
- and state.content_index.notion
- and state.search_models.text_search
- ):
- # query notion pages
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- state.search_models.text_search,
- state.content_index.notion,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
- if (
- (t == SearchType.Plaintext or t == SearchType.All)
- and state.content_index.plaintext
- and state.search_models.text_search
- ):
- # query plaintext files
- search_futures += [
- executor.submit(
- text_search.query,
- user_query,
- state.search_models.text_search,
- state.content_index.plaintext,
- question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
- dedupe=dedupe or True,
- )
- ]
-
# Query across each requested content types in parallel
with timer("Query took", logger):
for search_future in concurrent.futures.as_completed(search_futures):
@@ -576,15 +501,19 @@ async def search(
count=results_count,
)
else:
- hits, entries = await search_future.result()
+ hits = await search_future.result()
# Collate results
- results += text_search.collate_results(hits, entries, results_count)
+ results += text_search.collate_results(hits, dedupe=dedupe)
- # Sort results across all content types and take top results
- results = sorted(results, key=lambda x: float(x.score), reverse=True)[:results_count]
+ if r:
+ results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
+ else:
+ # Sort results across all content types and take top results
+ results = sorted(results, key=lambda x: float(x.score))[:results_count]
# Cache results
- state.query_cache[query_cache_key] = results
+ if user:
+ state.query_cache[user.uuid][query_cache_key] = results
update_telemetry_state(
request=request,
@@ -596,8 +525,6 @@ async def search(
host=host,
)
- state.previous_query = user_query
-
end_time = time.time()
logger.debug(f"🔍 Search took: {end_time - start_time:.3f} seconds")
@@ -614,12 +541,13 @@ def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
+ user = request.user.object if request.user.is_authenticated else None
if not state.config:
error_msg = f"🚨 Khoj is not configured.\nConfigure it via http://localhost:42110/config, plugins or by editing {state.config_file}."
logger.warning(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
try:
- configure_server(state.config, regenerate=force, search_type=t)
+ configure_server(state.config, regenerate=force, search_type=t, user=user)
except Exception as e:
error_msg = f"🚨 Failed to update server via API: {e}"
logger.error(error_msg, exc_info=True)
@@ -774,6 +702,7 @@ async def extract_references_and_questions(
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
+ user = request.user.object if request.user.is_authenticated else None
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
@@ -781,7 +710,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = []
inferred_queries: List[str] = []
- if state.content_index is None:
+ if not EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 6b42f29c..be9e8700 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -10,6 +10,7 @@ from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
+from database.models import KhojUser
logger = logging.getLogger(__name__)
@@ -40,11 +41,13 @@ def update_telemetry_state(
host: Optional[str] = None,
metadata: Optional[dict] = None,
):
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
+ "server_id": str(user.uuid) if user else None,
}
if metadata:
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index a9656050..c2ef04ff 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -1,6 +1,7 @@
# Standard Packages
import logging
from typing import Optional, Union, Dict
+import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Response, UploadFile
@@ -9,31 +10,30 @@ from khoj.routers.helpers import update_telemetry_state
# Internal Packages
from khoj.utils import state, constants
-from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.processor.notion.notion_to_jsonl import NotionToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
-from khoj.utils.rawconfig import ContentConfig, TextContentConfig
from khoj.search_type import text_search, image_search
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.config import SearchModels
-from khoj.utils.constants import default_config
from khoj.utils.helpers import LRU, get_file_type
from khoj.utils.rawconfig import (
ContentConfig,
FullConfig,
SearchConfig,
)
-from khoj.search_filter.date_filter import DateFilter
-from khoj.search_filter.word_filter import WordFilter
-from khoj.search_filter.file_filter import FileFilter
from khoj.utils.config import (
ContentIndex,
SearchModels,
)
+from database.models import (
+ KhojUser,
+ GithubConfig,
+ NotionConfig,
+)
logger = logging.getLogger(__name__)
@@ -68,14 +68,14 @@ async def update(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
+ user = request.user.object if request.user.is_authenticated else None
if x_api_key != "secret":
raise HTTPException(status_code=401, detail="Invalid API Key")
- state.config_lock.acquire()
try:
logger.info(f"📬 Updating content index via API call by {client} client")
org_files: Dict[str, str] = {}
markdown_files: Dict[str, str] = {}
- pdf_files: Dict[str, str] = {}
+ pdf_files: Dict[str, bytes] = {}
plaintext_files: Dict[str, str] = {}
for file in files:
@@ -86,7 +86,7 @@ async def update(
elif file_type == "markdown":
dict_to_update = markdown_files
elif file_type == "pdf":
- dict_to_update = pdf_files
+ dict_to_update = pdf_files # type: ignore
elif file_type == "plaintext":
dict_to_update = plaintext_files
@@ -120,30 +120,31 @@ async def update(
github=None,
notion=None,
plaintext=None,
- plugins=None,
)
state.config.content_type = default_content_config
save_config_to_file_updated_state()
configure_search(state.search_models, state.config.search_type)
# Extract required fields from config
- state.content_index = configure_content(
+ loop = asyncio.get_event_loop()
+ state.content_index = await loop.run_in_executor(
+ None,
+ configure_content,
state.content_index,
state.config.content_type,
indexer_input.dict(),
state.search_models,
- regenerate=force,
- t=t,
- full_corpus=False,
+ force,
+ t,
+ False,
+ user,
)
-
+ logger.info(f"Finished processing batch indexing request")
except Exception as e:
logger.error(
f"🚨 Failed to {force} update {t} content index triggered via API call by {client} client: {e}",
exc_info=True,
)
- finally:
- state.config_lock.release()
update_telemetry_state(
request=request,
@@ -167,11 +168,6 @@ def configure_search(search_models: SearchModels, search_config: Optional[Search
if search_models is None:
search_models = SearchModels()
- # Initialize Search Models
- if search_config.asymmetric:
- logger.info("🔍 📜 Setting up text search model")
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
-
if search_config.image:
logger.info("🔍 🌄 Setting up image search model")
search_models.image_search = image_search.initialize_model(search_config.image)
@@ -187,16 +183,9 @@ def configure_content(
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
full_corpus: bool = True,
+ user: KhojUser = None,
) -> Optional[ContentIndex]:
- def has_valid_text_config(config: TextContentConfig):
- return config.input_files or config.input_filter
-
- # Run Validation Checks
- if content_config is None:
- logger.warning("🚨 No Content configuration available.")
- return None
- if content_index is None:
- content_index = ContentIndex()
+ content_index = ContentIndex()
if t in [type.value for type in state.SearchType]:
t = state.SearchType(t).value
@@ -209,59 +198,30 @@ def configure_content(
try:
# Initialize Org Notes Search
- if (
- (t == None or t == state.SearchType.Org.value)
- and ((content_config.org and has_valid_text_config(content_config.org)) or files["org"])
- and search_models.text_search
- ):
- if content_config.org == None:
- logger.info("🦄 No configuration for orgmode notes. Using default configuration.")
- default_configuration = default_config["content-type"]["org"] # type: ignore
- content_config.org = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
-
+ if (t == None or t == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
- content_index.org = text_search.setup(
+ text_search.setup(
OrgToJsonl,
files.get("org"),
- content_config.org,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup org: {e}", exc_info=True)
try:
# Initialize Markdown Search
- if (
- (t == None or t == state.SearchType.Markdown.value)
- and ((content_config.markdown and has_valid_text_config(content_config.markdown)) or files["markdown"])
- and search_models.text_search
- and files["markdown"]
- ):
- if content_config.markdown == None:
- logger.info("💎 No configuration for markdown notes. Using default configuration.")
- default_configuration = default_config["content-type"]["markdown"] # type: ignore
- content_config.markdown = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
-
+ if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
- content_index.markdown = text_search.setup(
+ text_search.setup(
MarkdownToJsonl,
files.get("markdown"),
- content_config.markdown,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -269,30 +229,15 @@ def configure_content(
try:
# Initialize PDF Search
- if (
- (t == None or t == state.SearchType.Pdf.value)
- and ((content_config.pdf and has_valid_text_config(content_config.pdf)) or files["pdf"])
- and search_models.text_search
- and files["pdf"]
- ):
- if content_config.pdf == None:
- logger.info("🖨️ No configuration for pdf notes. Using default configuration.")
- default_configuration = default_config["content-type"]["pdf"] # type: ignore
- content_config.pdf = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
-
+ if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
- content_index.pdf = text_search.setup(
+ text_search.setup(
PdfToJsonl,
files.get("pdf"),
- content_config.pdf,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -300,30 +245,15 @@ def configure_content(
try:
# Initialize Plaintext Search
- if (
- (t == None or t == state.SearchType.Plaintext.value)
- and ((content_config.plaintext and has_valid_text_config(content_config.plaintext)) or files["plaintext"])
- and search_models.text_search
- and files["plaintext"]
- ):
- if content_config.plaintext == None:
- logger.info("📄 No configuration for plaintext notes. Using default configuration.")
- default_configuration = default_config["content-type"]["plaintext"] # type: ignore
- content_config.plaintext = TextContentConfig(
- compressed_jsonl=default_configuration["compressed-jsonl"],
- embeddings_file=default_configuration["embeddings-file"],
- )
-
+ if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
- content_index.plaintext = text_search.setup(
+ text_search.setup(
PlaintextToJsonl,
files.get("plaintext"),
- content_config.plaintext,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
)
except Exception as e:
@@ -331,7 +261,12 @@ def configure_content(
try:
# Initialize Image Search
- if (t == None or t == state.SearchType.Image.value) and content_config.image and search_models.image_search:
+ if (
+ (t == None or t == state.SearchType.Image.value)
+ and content_config
+ and content_config.image
+ and search_models.image_search
+ ):
logger.info("🌄 Setting up search for images")
# Extract Entries, Generate Image Embeddings
content_index.image = image_search.setup(
@@ -342,17 +277,17 @@ def configure_content(
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
try:
- if (t == None or t == state.SearchType.Github.value) and content_config.github and search_models.text_search:
+ github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
+ if (t == None or t == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
- content_index.github = text_search.setup(
+ text_search.setup(
GithubToJsonl,
None,
- content_config.github,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
+ config=github_config,
)
except Exception as e:
@@ -360,42 +295,24 @@ def configure_content(
try:
# Initialize Notion Search
- if (t == None or t in state.SearchType.Notion.value) and content_config.notion and search_models.text_search:
+ notion_config = NotionConfig.objects.filter(user=user).first()
+ if (t == None or t in state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
- content_index.notion = text_search.setup(
+ text_search.setup(
NotionToJsonl,
None,
- content_config.notion,
- search_models.text_search.bi_encoder,
regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
full_corpus=full_corpus,
+ user=user,
+ config=notion_config,
)
except Exception as e:
logger.error(f"🚨 Failed to setup GitHub: {e}", exc_info=True)
- try:
- # Initialize External Plugin Search
- if t == None and content_config.plugins and search_models.text_search:
- logger.info("🔌 Setting up search for plugins")
- content_index.plugins = {}
- for plugin_type, plugin_config in content_config.plugins.items():
- content_index.plugins[plugin_type] = text_search.setup(
- JsonlToJsonl,
- None,
- plugin_config,
- search_models.text_search.bi_encoder,
- regenerate=regenerate,
- filters=[DateFilter(), WordFilter(), FileFilter()],
- full_corpus=full_corpus,
- )
-
- except Exception as e:
- logger.error(f"🚨 Failed to setup Plugin: {e}", exc_info=True)
-
# Invalidate Query Cache
- state.query_cache = LRU()
+ if user:
+ state.query_cache[user.uuid] = LRU()
return content_index
@@ -412,44 +329,9 @@ def load_content(
if content_index is None:
content_index = ContentIndex()
- if content_config.org:
- logger.info("🦄 Loading orgmode notes")
- content_index.org = text_search.load(content_config.org, filters=[DateFilter(), WordFilter(), FileFilter()])
- if content_config.markdown:
- logger.info("💎 Loading markdown notes")
- content_index.markdown = text_search.load(
- content_config.markdown, filters=[DateFilter(), WordFilter(), FileFilter()]
- )
- if content_config.pdf:
- logger.info("🖨️ Loading pdf")
- content_index.pdf = text_search.load(content_config.pdf, filters=[DateFilter(), WordFilter(), FileFilter()])
- if content_config.plaintext:
- logger.info("📄 Loading plaintext")
- content_index.plaintext = text_search.load(
- content_config.plaintext, filters=[DateFilter(), WordFilter(), FileFilter()]
- )
if content_config.image:
logger.info("🌄 Loading images")
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
- if content_config.github:
- logger.info("🐙 Loading github")
- content_index.github = text_search.load(
- content_config.github, filters=[DateFilter(), WordFilter(), FileFilter()]
- )
- if content_config.notion:
- logger.info("🔌 Loading notion")
- content_index.notion = text_search.load(
- content_config.notion, filters=[DateFilter(), WordFilter(), FileFilter()]
- )
- if content_config.plugins:
- logger.info("🔌 Loading plugins")
- content_index.plugins = {}
- for plugin_type, plugin_config in content_config.plugins.items():
- content_index.plugins[plugin_type] = text_search.load(
- plugin_config, filters=[DateFilter(), WordFilter(), FileFilter()]
- )
-
- state.query_cache = LRU()
return content_index
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 492a263c..6c79e061 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -3,10 +3,20 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
-from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
+from starlette.authentication import requires
+from khoj.utils.rawconfig import (
+ TextContentConfig,
+ OpenAIProcessorConfig,
+ FullConfig,
+ GithubContentConfig,
+ GithubRepoConfig,
+ NotionContentConfig,
+)
# Internal Packages
from khoj.utils import constants, state
+from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
+from database.models import KhojUser, LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
import json
@@ -29,10 +39,23 @@ def chat_page(request: Request):
return templates.TemplateResponse("chat.html", context={"request": request, "demo": state.demo})
+def map_config_to_object(content_type: str):
+ if content_type == "org":
+ return LocalOrgConfig
+ if content_type == "markdown":
+ return LocalMarkdownConfig
+ if content_type == "pdf":
+ return LocalPdfConfig
+ if content_type == "plaintext":
+ return LocalPlaintextConfig
+
+
if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
def config_page(request: Request):
+ user = request.user.object if request.user.is_authenticated else None
+ enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
search_type=None,
@@ -41,13 +64,13 @@ if not state.demo:
current_config = state.config or json.loads(default_full_config.json())
successfully_configured = {
- "pdf": False,
- "markdown": False,
- "org": False,
+ "pdf": ("pdf" in enabled_content),
+ "markdown": ("markdown" in enabled_content),
+ "org": ("org" in enabled_content),
"image": False,
- "github": False,
- "notion": False,
- "plaintext": False,
+ "github": ("github" in enabled_content),
+ "notion": ("notion" in enabled_content),
+ "plaintext": ("plaintext" in enabled_content),
"enable_offline_model": False,
"conversation_openai": False,
"conversation_gpt4all": False,
@@ -56,13 +79,7 @@ if not state.demo:
if state.content_index:
successfully_configured.update(
{
- "pdf": state.content_index.pdf is not None,
- "markdown": state.content_index.markdown is not None,
- "org": state.content_index.org is not None,
"image": state.content_index.image is not None,
- "github": state.content_index.github is not None,
- "notion": state.content_index.notion is not None,
- "plaintext": state.content_index.plaintext is not None,
}
)
@@ -84,22 +101,29 @@ if not state.demo:
)
@web_client.get("/config/content_type/github", response_class=HTMLResponse)
+ @requires(["authenticated"])
def github_config_page(request: Request):
- default_copy = constants.default_config.copy()
- default_github = default_copy["content-type"]["github"] # type: ignore
+ user = request.user.object if request.user.is_authenticated else None
+ current_github_config = get_user_github_config(user)
- default_config = TextContentConfig(
- compressed_jsonl=default_github["compressed-jsonl"],
- embeddings_file=default_github["embeddings-file"],
- )
-
- current_config = (
- state.config.content_type.github
- if state.config and state.config.content_type and state.config.content_type.github
- else default_config
- )
-
- current_config = json.loads(current_config.json())
+ if current_github_config:
+ raw_repos = current_github_config.githubrepoconfig.all()
+ repos = []
+ for repo in raw_repos:
+ repos.append(
+ GithubRepoConfig(
+ name=repo.name,
+ owner=repo.owner,
+ branch=repo.branch,
+ )
+ )
+ current_config = GithubContentConfig(
+ pat_token=current_github_config.pat_token,
+ repos=repos,
+ )
+ current_config = json.loads(current_config.json())
+ else:
+ current_config = {} # type: ignore
return templates.TemplateResponse(
"content_type_github_input.html", context={"request": request, "current_config": current_config}
@@ -107,18 +131,11 @@ if not state.demo:
@web_client.get("/config/content_type/notion", response_class=HTMLResponse)
def notion_config_page(request: Request):
- default_copy = constants.default_config.copy()
- default_notion = default_copy["content-type"]["notion"] # type: ignore
+ user = request.user.object if request.user.is_authenticated else None
+ current_notion_config = get_user_notion_config(user)
- default_config = TextContentConfig(
- compressed_jsonl=default_notion["compressed-jsonl"],
- embeddings_file=default_notion["embeddings-file"],
- )
-
- current_config = (
- state.config.content_type.notion
- if state.config and state.config.content_type and state.config.content_type.notion
- else default_config
+ current_config = NotionContentConfig(
+ token=current_notion_config.token if current_notion_config else "",
)
current_config = json.loads(current_config.json())
@@ -132,18 +149,16 @@ if not state.demo:
if content_type not in VALID_TEXT_CONTENT_TYPES:
return templates.TemplateResponse("config.html", context={"request": request})
- default_copy = constants.default_config.copy()
- default_content_type = default_copy["content-type"][content_type] # type: ignore
+ object = map_config_to_object(content_type)
+ user = request.user.object if request.user.is_authenticated else None
+ config = object.objects.filter(user=user).first()
+ if config == None:
+ config = object.objects.create(user=user)
- default_config = TextContentConfig(
- compressed_jsonl=default_content_type["compressed-jsonl"],
- embeddings_file=default_content_type["embeddings-file"],
- )
-
- current_config = (
- state.config.content_type[content_type]
- if state.config and state.config.content_type and state.config.content_type[content_type] # type: ignore
- else default_config
+ current_config = TextContentConfig(
+ input_files=config.input_files,
+ input_filter=config.input_filter,
+ index_heading_entries=config.index_heading_entries,
)
current_config = json.loads(current_config.json())
diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py
index 470f7341..ae596587 100644
--- a/src/khoj/search_filter/base_filter.py
+++ b/src/khoj/search_filter/base_filter.py
@@ -1,16 +1,9 @@
# Standard Packages
from abc import ABC, abstractmethod
-from typing import List, Set, Tuple
-
-# Internal Packages
-from khoj.utils.rawconfig import Entry
+from typing import List
class BaseFilter(ABC):
- @abstractmethod
- def load(self, entries: List[Entry], *args, **kwargs):
- ...
-
@abstractmethod
def get_filter_terms(self, query: str) -> List[str]:
...
@@ -18,10 +11,6 @@ class BaseFilter(ABC):
def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0
- @abstractmethod
- def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
- ...
-
@abstractmethod
def defilter(self, query: str) -> str:
...
diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py
index 39e7bec3..88c70101 100644
--- a/src/khoj/search_filter/date_filter.py
+++ b/src/khoj/search_filter/date_filter.py
@@ -25,72 +25,42 @@ class DateFilter(BaseFilter):
# - dt>="last week"
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
+ raw_date_regex = r"\d{4}-\d{2}-\d{2}"
def __init__(self, entry_key="compiled"):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
- def load(self, entries, *args, **kwargs):
- with timer("Created date filter index", logger):
- for id, entry in enumerate(entries):
- # Extract dates from entry
- for date_in_entry_string in re.findall(r"\d{4}-\d{2}-\d{2}", getattr(entry, self.entry_key)):
- # Convert date string in entry to unix timestamp
- try:
- date_in_entry = datetime.strptime(date_in_entry_string, "%Y-%m-%d").timestamp()
- except ValueError:
- continue
- except OSError:
- logger.debug(f"OSError: Ignoring unprocessable date in entry: {date_in_entry_string}")
- continue
- self.date_to_entry_ids[date_in_entry].add(id)
+ def extract_dates(self, content):
+ pattern_matched_dates = re.findall(self.raw_date_regex, content)
+
+ # Filter down to valid dates
+ valid_dates = []
+ for date_str in pattern_matched_dates:
+ try:
+ valid_dates.append(datetime.strptime(date_str, "%Y-%m-%d"))
+ except ValueError:
+ continue
+
+ return valid_dates
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
+ def get_query_date_range(self, query) -> List:
+ with timer("Extract date range to filter from query", logger):
+ query_daterange = self.extract_date_range(query)
+
+ return query_daterange
+
def defilter(self, query):
# remove date range filter from query
query = re.sub(rf"\s+{self.date_regex}", " ", query)
query = re.sub(r"\s{2,}", " ", query).strip() # remove multiple spaces
return query
- def apply(self, query, entries):
- "Find entries containing any dates that fall within date range specified in query"
- # extract date range specified in date filter of query
- with timer("Extract date range to filter from query", logger):
- query_daterange = self.extract_date_range(query)
-
- # if no date in query, return all entries
- if query_daterange == []:
- return query, set(range(len(entries)))
-
- query = self.defilter(query)
-
- # return results from cache if exists
- cache_key = tuple(query_daterange)
- if cache_key in self.cache:
- logger.debug(f"Return date filter results from cache")
- entries_to_include = self.cache[cache_key]
- return query, entries_to_include
-
- if not self.date_to_entry_ids:
- self.load(entries)
-
- # find entries containing any dates that fall with date range specified in query
- with timer("Mark entries satisfying filter", logger):
- entries_to_include = set()
- for date_in_entry in self.date_to_entry_ids.keys():
- # Check if date in entry is within date range specified in query
- if query_daterange[0] <= date_in_entry < query_daterange[1]:
- entries_to_include |= self.date_to_entry_ids[date_in_entry]
-
- # cache results
- self.cache[cache_key] = entries_to_include
-
- return query, entries_to_include
-
def extract_date_range(self, query):
# find date range filter in query
date_range_matches = re.findall(self.date_regex, query)
@@ -138,6 +108,15 @@ class DateFilter(BaseFilter):
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return []
else:
+ # If the first element is 0, replace it with None
+
+ if effective_date_range[0] == 0:
+ effective_date_range[0] = None
+
+ # If the second element is inf, replace it with None
+ if effective_date_range[1] == inf:
+ effective_date_range[1] = None
+
return effective_date_range
def parse(self, date_str, relative_base=None):
diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py
index 420bf9e7..291838ea 100644
--- a/src/khoj/search_filter/file_filter.py
+++ b/src/khoj/search_filter/file_filter.py
@@ -21,62 +21,13 @@ class FileFilter(BaseFilter):
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
- def load(self, entries, *args, **kwargs):
- with timer("Created file filter index", logger):
- for id, entry in enumerate(entries):
- self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
-
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
- return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
+ return [f"{self.convert_to_regex(term)}" for term in re.findall(self.file_filter_regex, query)]
+
+ def convert_to_regex(self, file_filter: str) -> str:
+ "Convert file filter to regex"
+ return file_filter.replace(".", r"\.").replace("*", r".*")
def defilter(self, query: str) -> str:
return re.sub(self.file_filter_regex, "", query).strip()
-
- def apply(self, query, entries):
- # Extract file filters from raw query
- with timer("Extract files_to_search from query", logger):
- raw_files_to_search = re.findall(self.file_filter_regex, query)
- if not raw_files_to_search:
- return query, set(range(len(entries)))
-
- # Convert simple file filters with no path separator into regex
- # e.g. "file:notes.org" -> "file:.*notes.org"
- files_to_search = []
- for file in sorted(raw_files_to_search):
- if "/" not in file and "\\" not in file and "*" not in file:
- files_to_search += [f"*{file}"]
- else:
- files_to_search += [file]
-
- # Remove filter terms from original query
- query = self.defilter(query)
-
- # Return item from cache if exists
- cache_key = tuple(files_to_search)
- if cache_key in self.cache:
- logger.debug(f"Return file filter results from cache")
- included_entry_indices = self.cache[cache_key]
- return query, included_entry_indices
-
- if not self.file_to_entry_map:
- self.load(entries, regenerate=False)
-
- # Mark entries that contain any blocked_words for exclusion
- with timer("Mark entries satisfying filter", logger):
- included_entry_indices = set.union(
- *[
- self.file_to_entry_map[entry_file]
- for entry_file in self.file_to_entry_map.keys()
- for search_file in files_to_search
- if fnmatch.fnmatch(entry_file, search_file)
- ],
- set(),
- )
- if not included_entry_indices:
- return query, {}
-
- # Cache results
- self.cache[cache_key] = included_entry_indices
-
- return query, included_entry_indices
diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py
index ebf64b34..b2053dbe 100644
--- a/src/khoj/search_filter/word_filter.py
+++ b/src/khoj/search_filter/word_filter.py
@@ -6,7 +6,7 @@ from typing import List
# Internal Packages
from khoj.search_filter.base_filter import BaseFilter
-from khoj.utils.helpers import LRU, timer
+from khoj.utils.helpers import LRU
logger = logging.getLogger(__name__)
@@ -22,21 +22,6 @@ class WordFilter(BaseFilter):
self.word_to_entry_index = defaultdict(set)
self.cache = LRU()
- def load(self, entries, *args, **kwargs):
- with timer("Created word filter index", logger):
- self.cache = {} # Clear cache on filter (re-)load
- entry_splitter = (
- r",|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'"
- )
- # Create map of words to entries they exist in
- for entry_index, entry in enumerate(entries):
- for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
- if word == "":
- continue
- self.word_to_entry_index[word].add(entry_index)
-
- return self.word_to_entry_index
-
def get_filter_terms(self, query: str) -> List[str]:
"Get all filter terms in query"
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
@@ -45,47 +30,3 @@ class WordFilter(BaseFilter):
def defilter(self, query: str) -> str:
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
-
- def apply(self, query, entries):
- "Find entries containing required and not blocked words specified in query"
- # Separate natural query from required, blocked words filters
- with timer("Extract required, blocked filters from query", logger):
- required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
- blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
- query = self.defilter(query)
-
- if len(required_words) == 0 and len(blocked_words) == 0:
- return query, set(range(len(entries)))
-
- # Return item from cache if exists
- cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
- if cache_key in self.cache:
- logger.debug(f"Return word filter results from cache")
- included_entry_indices = self.cache[cache_key]
- return query, included_entry_indices
-
- if not self.word_to_entry_index:
- self.load(entries, regenerate=False)
-
- # mark entries that contain all required_words for inclusion
- with timer("Mark entries satisfying filter", logger):
- entries_with_all_required_words = set(range(len(entries)))
- if len(required_words) > 0:
- entries_with_all_required_words = set.intersection(
- *[self.word_to_entry_index.get(word, set()) for word in required_words]
- )
-
- # mark entries that contain any blocked_words for exclusion
- entries_with_any_blocked_words = set()
- if len(blocked_words) > 0:
- entries_with_any_blocked_words = set.union(
- *[self.word_to_entry_index.get(word, set()) for word in blocked_words]
- )
-
- # get entries satisfying inclusion and exclusion filters
- included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
-
- # Cache results
- self.cache[cache_key] = included_entry_indices
-
- return query, included_entry_indices
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 2890baa9..36d6a791 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -2,25 +2,39 @@
import logging
import math
from pathlib import Path
-from typing import List, Tuple, Type, Union
+from typing import List, Tuple, Type, Union, Dict
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
-from khoj.processor.text_to_jsonl import TextToJsonl
-from khoj.search_filter.base_filter import BaseFilter
+
+from asgiref.sync import sync_to_async
+
# Internal Packages
from khoj.utils import state
-from khoj.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
-from khoj.utils.config import TextContent, TextSearchModel
+from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
+from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
-from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, TextConfigBase, Entry
+from khoj.utils.state import SearchType
+from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
from khoj.utils.jsonl import load_jsonl
-
+from khoj.processor.text_to_jsonl import TextEmbeddings
+from database.adapters import EmbeddingsAdapters
+from database.models import KhojUser, Embeddings
logger = logging.getLogger(__name__)
+search_type_to_embeddings_type = {
+ SearchType.Org.value: Embeddings.EmbeddingsType.ORG,
+ SearchType.Markdown.value: Embeddings.EmbeddingsType.MARKDOWN,
+ SearchType.Plaintext.value: Embeddings.EmbeddingsType.PLAINTEXT,
+ SearchType.Pdf.value: Embeddings.EmbeddingsType.PDF,
+ SearchType.Github.value: Embeddings.EmbeddingsType.GITHUB,
+ SearchType.Notion.value: Embeddings.EmbeddingsType.NOTION,
+ SearchType.All.value: None,
+}
+
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
@@ -117,171 +131,102 @@ def load_embeddings(
async def query(
+ user: KhojUser,
raw_query: str,
- search_model: TextSearchModel,
- content: TextContent,
+ type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False,
score_threshold: float = -math.inf,
- dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
- if (
- content.entries is None
- or len(content.entries) == 0
- or content.corpus_embeddings is None
- or len(content.corpus_embeddings) == 0
- ):
- return [], []
- query, entries, corpus_embeddings = raw_query, content.entries, content.corpus_embeddings
+ file_type = search_type_to_embeddings_type[type.value]
- # Filter query, entries and embeddings before semantic search
- query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, content.filters)
-
- # If no entries left after filtering, return empty results
- if entries is None or len(entries) == 0:
- return [], []
- # If query only had filters it'll be empty now. So short-circuit and return results.
- if query.strip() == "":
- hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
- return hits, entries
+ query = raw_query
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
- question_embedding = search_model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
- question_embedding = util.normalize_embeddings(question_embedding)
+ question_embedding = state.embeddings_model.embed_query(query)
# Find relevant entries for the query
- top_k = min(len(entries), search_model.top_k or 10) # top_k hits can't be more than the total entries in corpus
+ top_k = 10
with timer("Search Time", logger, state.device):
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k, score_function=util.dot_score)[0]
+ hits = EmbeddingsAdapters.search_with_embeddings(
+ user=user,
+ embeddings=question_embedding,
+ max_results=top_k,
+ file_type_filter=file_type,
+ raw_query=raw_query,
+ ).all()
+ hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
+ return hits
+
+
+def collate_results(hits, dedupe=True):
+ hit_ids = set()
+ for hit in hits:
+ if dedupe and hit.corpus_id in hit_ids:
+ continue
+
+ else:
+ hit_ids.add(hit.corpus_id)
+ yield SearchResponse.parse_obj(
+ {
+ "entry": hit.raw,
+ "score": hit.distance,
+ "additional": {
+ "file": hit.file_path,
+ "compiled": hit.compiled,
+ "heading": hit.heading,
+ },
+ }
+ )
+
+
+def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
- if rank_results and search_model.cross_encoder:
- hits = cross_encoder_score(search_model.cross_encoder, query, entries, hits)
+ hits = cross_encoder_score(query, hits)
- # Filter results by score threshold
- hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
+ # Sort results by cross-encoder score followed by bi-encoder score
+ hits = sort_results(rank_results=True, hits=hits)
- # Order results by cross-encoder score followed by bi-encoder score
- hits = sort_results(rank_results, hits)
-
- # Deduplicate entries by raw entry text before showing to users
- if dedupe:
- hits = deduplicate_results(entries, hits)
-
- return hits, entries
-
-
-def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
- return [
- SearchResponse.parse_obj(
- {
- "entry": entries[hit["corpus_id"]].raw,
- "score": f"{hit.get('cross-score') or hit.get('score')}",
- "additional": {
- "file": entries[hit["corpus_id"]].file,
- "compiled": entries[hit["corpus_id"]].compiled,
- "heading": entries[hit["corpus_id"]].heading,
- },
- }
- )
- for hit in hits[0:count]
- ]
+ return hits
def setup(
- text_to_jsonl: Type[TextToJsonl],
+ text_to_jsonl: Type[TextEmbeddings],
files: dict[str, str],
- config: TextConfigBase,
- bi_encoder: BaseEncoder,
regenerate: bool,
- filters: List[BaseFilter] = [],
- normalize: bool = True,
full_corpus: bool = True,
-) -> TextContent:
- # Map notes in text files to (compressed) JSONL formatted file
- config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
- previous_entries = []
- if config.compressed_jsonl.exists() and not regenerate:
- previous_entries = extract_entries(config.compressed_jsonl)
- entries_with_indices = text_to_jsonl(config).process(
- previous_entries=previous_entries, files=files, full_corpus=full_corpus
- )
-
- # Extract Updated Entries
- entries = extract_entries(config.compressed_jsonl)
- if is_none_or_empty(entries):
- config_params = ", ".join([f"{key}={value}" for key, value in config.dict().items()])
- raise ValueError(
- f"No valid entries found in specified configuration: {config_params}, with files: {files.keys()}"
+ user: KhojUser = None,
+ config=None,
+) -> None:
+ if config:
+ num_new_embeddings, num_deleted_embeddings = text_to_jsonl(config).process(
+ files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
+ )
+ else:
+ num_new_embeddings, num_deleted_embeddings = text_to_jsonl().process(
+ files=files, full_corpus=full_corpus, user=user, regenerate=regenerate
)
- # Compute or Load Embeddings
- config.embeddings_file = resolve_absolute_path(config.embeddings_file)
- corpus_embeddings = compute_embeddings(
- entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
+ file_names = [file_name for file_name in files]
+
+ logger.info(
+ f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
)
- for filter in filters:
- filter.load(entries, regenerate=regenerate)
- return TextContent(entries, corpus_embeddings, filters)
-
-
-def load(
- config: TextConfigBase,
- filters: List[BaseFilter] = [],
-) -> TextContent:
- # Map notes in text files to (compressed) JSONL formatted file
- config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
- entries = extract_entries(config.compressed_jsonl)
-
- # Compute or Load Embeddings
- config.embeddings_file = resolve_absolute_path(config.embeddings_file)
- corpus_embeddings = load_embeddings(config.embeddings_file)
-
- for filter in filters:
- filter.load(entries, regenerate=False)
-
- return TextContent(entries, corpus_embeddings, filters)
-
-
-def apply_filters(
- query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]
-) -> Tuple[str, List[Entry], torch.Tensor]:
- """Filter query, entries and embeddings before semantic search"""
-
- with timer("Total Filter Time", logger, state.device):
- included_entry_indices = set(range(len(entries)))
- filters_in_query = [filter for filter in filters if filter.can_filter(query)]
- for filter in filters_in_query:
- query, included_entry_indices_by_filter = filter.apply(query, entries)
- included_entry_indices.intersection_update(included_entry_indices_by_filter)
-
- # Get entries (and associated embeddings) satisfying all filters
- if not included_entry_indices:
- return "", [], torch.tensor([], device=state.device)
- else:
- entries = [entries[id] for id in included_entry_indices]
- corpus_embeddings = torch.index_select(
- corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)
- )
-
- return query, entries, corpus_embeddings
-
-
-def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
+def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
- cross_inp = [[query, entries[hit["corpus_id"]].compiled] for hit in hits]
- cross_scores = cross_encoder.predict(cross_inp)
+ cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
- hits[idx]["cross-score"] = cross_scores[idx]
+ hits[idx]["cross_score"] = cross_scores[idx]
return hits
@@ -291,23 +236,5 @@ def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
if rank_results:
- hits.sort(key=lambda x: x["cross-score"], reverse=True) # sort by cross-encoder score
- return hits
-
-
-def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
- """Deduplicate entries by raw entry text before showing to users
- Compiled entries are split by max tokens supported by ML models.
- This can result in duplicate hits, entries shown to user."""
-
- with timer("Deduplication Time", logger, state.device):
- seen, original_hits_count = set(), len(hits)
- hits = [
- hit
- for hit in hits
- if entries[hit["corpus_id"]].raw not in seen and not seen.add(entries[hit["corpus_id"]].raw) # type: ignore[func-returns-value]
- ]
- duplicate_hits = original_hits_count - len(hits)
-
- logger.debug(f"Removed {duplicate_hits} duplicates")
+ hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
return hits
diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py
index 1d6106cb..c72320a1 100644
--- a/src/khoj/utils/cli.py
+++ b/src/khoj/utils/cli.py
@@ -2,6 +2,7 @@
import argparse
import pathlib
from importlib.metadata import version
+import os
# Internal Packages
from khoj.utils.helpers import resolve_absolute_path
@@ -34,6 +35,12 @@ def cli(args=None):
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
+ parser.add_argument(
+ "--anonymous-mode",
+ action="store_true",
+ default=False,
+ help="Run Khoj in anonymous mode. This does not require any login for connecting users.",
+ )
args = parser.parse_args(args)
@@ -51,6 +58,8 @@ def cli(args=None):
else:
args = run_migrations(args)
args.config = parse_config_from_file(args.config_file)
+ if os.environ.get("DEBUG"):
+ args.config.app.should_log_telemetry = False
return args
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 5b3b9f6e..ee5b4f9f 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -41,9 +41,7 @@ class ProcessorType(str, Enum):
@dataclass
class TextContent:
- entries: List[Entry]
- corpus_embeddings: torch.Tensor
- filters: List[BaseFilter]
+ enabled: bool
@dataclass
@@ -67,21 +65,13 @@ class ImageSearchModel:
@dataclass
class ContentIndex:
- org: Optional[TextContent] = None
- markdown: Optional[TextContent] = None
- pdf: Optional[TextContent] = None
- github: Optional[TextContent] = None
- notion: Optional[TextContent] = None
image: Optional[ImageContent] = None
- plaintext: Optional[TextContent] = None
- plugins: Optional[Dict[str, TextContent]] = None
@dataclass
class SearchModels:
text_search: Optional[TextSearchModel] = None
image_search: Optional[ImageSearchModel] = None
- plugin_search: Optional[Dict[str, TextSearchModel]] = None
@dataclass
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index 9ed97798..181dee04 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -5,6 +5,7 @@ web_directory = app_root_directory / "khoj/interface/web/"
empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
+content_directory = "~/.khoj/content/"
empty_config = {
"content-type": {
diff --git a/src/khoj/utils/fs_syncer.py b/src/khoj/utils/fs_syncer.py
index 44fc70ad..fc7e4a2d 100644
--- a/src/khoj/utils/fs_syncer.py
+++ b/src/khoj/utils/fs_syncer.py
@@ -5,29 +5,39 @@ from typing import Optional
from bs4 import BeautifulSoup
from khoj.utils.helpers import get_absolute_path, is_none_or_empty
-from khoj.utils.rawconfig import TextContentConfig, ContentConfig
+from khoj.utils.rawconfig import TextContentConfig
from khoj.utils.config import SearchType
+from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig
logger = logging.getLogger(__name__)
-def collect_files(config: ContentConfig, search_type: Optional[SearchType] = SearchType.All):
+def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
files = {}
- if config is None:
- return files
-
if search_type == SearchType.All or search_type == SearchType.Org:
- files["org"] = get_org_files(config.org) if config.org else {}
+ org_config = LocalOrgConfig.objects.filter(user=user).first()
+ files["org"] = get_org_files(construct_config_from_db(org_config)) if org_config else {}
if search_type == SearchType.All or search_type == SearchType.Markdown:
- files["markdown"] = get_markdown_files(config.markdown) if config.markdown else {}
+ markdown_config = LocalMarkdownConfig.objects.filter(user=user).first()
+ files["markdown"] = get_markdown_files(construct_config_from_db(markdown_config)) if markdown_config else {}
if search_type == SearchType.All or search_type == SearchType.Plaintext:
- files["plaintext"] = get_plaintext_files(config.plaintext) if config.plaintext else {}
+ plaintext_config = LocalPlaintextConfig.objects.filter(user=user).first()
+ files["plaintext"] = get_plaintext_files(construct_config_from_db(plaintext_config)) if plaintext_config else {}
if search_type == SearchType.All or search_type == SearchType.Pdf:
- files["pdf"] = get_pdf_files(config.pdf) if config.pdf else {}
+ pdf_config = LocalPdfConfig.objects.filter(user=user).first()
+ files["pdf"] = get_pdf_files(construct_config_from_db(pdf_config)) if pdf_config else {}
return files
+def construct_config_from_db(db_config) -> TextContentConfig:
+ return TextContentConfig(
+ input_files=db_config.input_files,
+ input_filter=db_config.input_filter,
+ index_heading_entries=db_config.index_heading_entries,
+ )
+
+
def get_plaintext_files(config: TextContentConfig) -> dict[str, str]:
def is_plaintextfile(file: str):
"Check if file is plaintext file"
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 9209ff67..e41791f9 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -209,10 +209,12 @@ def log_telemetry(
if not app_config or not app_config.should_log_telemetry:
return []
+ if properties.get("server_id") is None:
+ properties["server_id"] = get_server_id()
+
# Populate telemetry data to log
request_body = {
"telemetry_type": telemetry_type,
- "server_id": get_server_id(),
"server_version": version("khoj-assistant"),
"os": platform.system(),
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index f7c42266..5d2b3ce4 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -1,13 +1,14 @@
# System Packages
import json
from pathlib import Path
-from typing import List, Dict, Optional, Union, Any
+from typing import List, Dict, Optional
+import uuid
# External Packages
-from pydantic import BaseModel, validator
+from pydantic import BaseModel
# Internal Packages
-from khoj.utils.helpers import to_snake_case_from_dash, is_none_or_empty
+from khoj.utils.helpers import to_snake_case_from_dash
class ConfigBase(BaseModel):
@@ -27,7 +28,7 @@ class TextConfigBase(ConfigBase):
embeddings_file: Path
-class TextContentConfig(TextConfigBase):
+class TextContentConfig(ConfigBase):
input_files: Optional[List[Path]]
input_filter: Optional[List[str]]
index_heading_entries: Optional[bool] = False
@@ -39,12 +40,12 @@ class GithubRepoConfig(ConfigBase):
branch: Optional[str] = "master"
-class GithubContentConfig(TextConfigBase):
+class GithubContentConfig(ConfigBase):
pat_token: str
repos: List[GithubRepoConfig]
-class NotionContentConfig(TextConfigBase):
+class NotionContentConfig(ConfigBase):
token: str
@@ -63,7 +64,6 @@ class ContentConfig(ConfigBase):
pdf: Optional[TextContentConfig]
plaintext: Optional[TextContentConfig]
github: Optional[GithubContentConfig]
- plugins: Optional[Dict[str, TextContentConfig]]
notion: Optional[NotionContentConfig]
@@ -122,7 +122,8 @@ class FullConfig(ConfigBase):
class SearchResponse(ConfigBase):
entry: str
- score: str
+ score: float
+ cross_score: Optional[float]
additional: Optional[dict]
@@ -131,14 +132,21 @@ class Entry:
compiled: str
heading: Optional[str]
file: Optional[str]
+ corpus_id: str
def __init__(
- self, raw: str = None, compiled: str = None, heading: Optional[str] = None, file: Optional[str] = None
+ self,
+ raw: str = None,
+ compiled: str = None,
+ heading: Optional[str] = None,
+ file: Optional[str] = None,
+ corpus_id: uuid.UUID = None,
):
self.raw = raw
self.compiled = compiled
self.heading = heading
self.file = file
+ self.corpus_id = str(corpus_id)
def to_json(self) -> str:
return json.dumps(self.__dict__, ensure_ascii=False)
@@ -153,4 +161,5 @@ class Entry:
compiled=dictionary["compiled"],
file=dictionary.get("file", None),
heading=dictionary.get("heading", None),
+ corpus_id=dictionary.get("corpus_id", None),
)
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 5ac8a838..d6169d2a 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -2,6 +2,7 @@
import threading
from typing import List, Dict
from packaging import version
+from collections import defaultdict
# External Packages
import torch
@@ -12,10 +13,13 @@ from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
+from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State
config = FullConfig()
search_models = SearchModels()
+embeddings_model = EmbeddingsModel()
+cross_encoder_model = CrossEncoderModel()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
config_file: Path = None
@@ -23,14 +27,14 @@ verbose: int = 0
host: str = None
port: int = None
cli_args: List[str] = None
-query_cache = LRU()
+query_cache: Dict[str, LRU] = defaultdict(LRU)
config_lock = threading.Lock()
chat_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []
-previous_query: str = None
demo: bool = False
khoj_version: str = None
+anonymous_mode: bool = False
if torch.cuda.is_available():
# Use CUDA GPU
diff --git a/tests/conftest.py b/tests/conftest.py
index 4f7dfb10..5f515ef1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,15 +1,19 @@
# External Packages
import os
-from copy import deepcopy
from fastapi.testclient import TestClient
from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI
+import factory
+import os
+from fastapi import FastAPI
+
+app = FastAPI()
+
# Internal Packages
-from app.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
-from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
- GithubContentConfig,
- GithubRepoConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
@@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
)
from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
-from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
-from khoj.search_filter.date_filter import DateFilter
-from khoj.search_filter.word_filter import WordFilter
-from khoj.search_filter.file_filter import FileFilter
+from database.models import (
+ LocalOrgConfig,
+ LocalMarkdownConfig,
+ LocalPlaintextConfig,
+ LocalPdfConfig,
+ GithubConfig,
+ KhojUser,
+ GithubRepoConfig,
+)
+
+
+@pytest.fixture(autouse=True)
+def enable_db_access_for_all_tests(db):
+ pass
+
+
+class UserFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = KhojUser
+
+ username = factory.Faker("name")
+ email = factory.Faker("email")
+ password = factory.Faker("password")
+ uuid = factory.Faker("uuid4")
@pytest.fixture(scope="session")
@@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
return search_config
+@pytest.mark.django_db
+@pytest.fixture
+def default_user():
+ return UserFactory()
+
+
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
-@pytest.fixture(scope="session")
-def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
+@pytest.fixture
+def anyio_backend():
+ return "asyncio"
+
+
+@pytest.mark.django_db
+@pytest.fixture(scope="function")
+def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
@@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
- # Generate Notes Embeddings from Test Notes
- content_config.org = TextContentConfig(
+ LocalOrgConfig.objects.create(
input_files=None,
input_filter=["tests/data/org/*.org"],
- compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
- embeddings_file=content_dir.joinpath("note_embeddings.pt"),
+ index_heading_entries=False,
+ user=default_user,
)
- filters = [DateFilter(), WordFilter(), FileFilter()]
- text_search.setup(
- OrgToJsonl,
- get_sample_data("org"),
- content_config.org,
- search_models.text_search.bi_encoder,
- regenerate=False,
- filters=filters,
- )
-
- content_config.plugins = {
- "plugin1": TextContentConfig(
- input_files=[content_dir.joinpath("notes.jsonl.gz")],
- input_filter=None,
- compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
- embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
- )
- }
+ text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
if os.getenv("GITHUB_PAT_TOKEN"):
- content_config.github = GithubContentConfig(
- pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
- repos=[
- GithubRepoConfig(
- owner="khoj-ai",
- name="lantern",
- branch="master",
- )
- ],
- compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
- embeddings_file=content_dir.joinpath("github_embeddings.pt"),
+ GithubConfig.objects.create(
+ pat_token=os.getenv("GITHUB_PAT_TOKEN"),
+ user=default_user,
)
- content_config.plaintext = TextContentConfig(
+ GithubRepoConfig.objects.create(
+ owner="khoj-ai",
+ name="lantern",
+ branch="master",
+ github_config=GithubConfig.objects.get(user=default_user),
+ )
+
+ LocalPlaintextConfig.objects.create(
input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
- compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
- embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
- )
-
- content_config.github = GithubContentConfig(
- pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
- repos=[
- GithubRepoConfig(
- owner="khoj-ai",
- name="lantern",
- branch="master",
- )
- ],
- compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
- embeddings_file=content_dir.joinpath("github_embeddings.pt"),
- )
-
- filters = [DateFilter(), WordFilter(), FileFilter()]
- text_search.setup(
- JsonlToJsonl,
- None,
- content_config.plugins["plugin1"],
- search_models.text_search.bi_encoder,
- regenerate=False,
- filters=filters,
+ user=default_user,
)
return content_config
@pytest.fixture(scope="session")
-def md_content_config(tmp_path_factory):
- content_dir = tmp_path_factory.mktemp("content")
-
- # Generate Embeddings for Markdown Content
- content_config = ContentConfig()
- content_config.markdown = TextContentConfig(
+def md_content_config():
+ markdown_config = LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
- compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
- embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
)
- return content_config
+ return markdown_config
@pytest.fixture(scope="session")
@@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
# Initialize app state
- state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
- state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- all_files = fs_syncer.collect_files(state.config.content_type)
+ all_files = fs_syncer.collect_files()
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config)
+ state.anonymous_mode = True
+
+ app = FastAPI()
configure_routes(app)
configure_middleware(app)
@@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
@pytest.fixture(scope="function")
-def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
+def fastapi_app():
+ app = FastAPI()
+ configure_routes(app)
+ configure_middleware(app)
+ app.mount("/static", StaticFiles(directory=web_directory), name="static")
+ return app
+
+
+@pytest.fixture(scope="function")
+def client(
+ content_config: ContentConfig,
+ search_config: SearchConfig,
+ processor_config: ProcessorConfig,
+ default_user: KhojUser,
+):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
- state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
- state.content_index.org = text_search.setup(
+ text_search.setup(
OrgToJsonl,
get_sample_data("org"),
- content_config.org,
- state.search_models.text_search.bi_encoder,
regenerate=False,
+ user=default_user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
- state.content_index.plaintext = text_search.setup(
+ text_search.setup(
PlaintextToJsonl,
get_sample_data("plaintext"),
- content_config.plaintext,
- state.search_models.text_search.bi_encoder,
regenerate=False,
+ user=default_user,
)
state.processor_config = configure_processor(processor_config)
+ state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@@ -288,7 +285,6 @@ def client_offline_chat(
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
- state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(state.config.content_type)
@@ -298,6 +294,7 @@ def client_offline_chat(
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat)
+ state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@@ -306,9 +303,11 @@ def client_offline_chat(
@pytest.fixture(scope="function")
-def new_org_file(content_config: ContentConfig):
+def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup
- new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
+ org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+ input_filters = org_config.input_filter
+ new_org_file = Path(input_filters[0]).parent / "new_file.org"
new_org_file.touch()
yield new_org_file
@@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
@pytest.fixture(scope="function")
-def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
- new_org_config = deepcopy(content_config.org)
- new_org_config.input_files = [f"{new_org_file}"]
- new_org_config.input_filter = None
- return new_org_config
+def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
+ LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
+ return LocalOrgConfig.objects.filter(user=default_user).first()
@pytest.fixture(scope="function")
diff --git a/tests/data/config.yml b/tests/data/config.yml
index 06978cf1..c544eebe 100644
--- a/tests/data/config.yml
+++ b/tests/data/config.yml
@@ -9,17 +9,6 @@ content-type:
input-filter:
- '*.org'
- ~/notes/*.org
- plugins:
- content_plugin_1:
- compressed-jsonl: content_plugin_1.jsonl.gz
- embeddings-file: content_plugin_1_embeddings.pt
- input-files:
- - content_plugin_1_new.jsonl.gz
- content_plugin_2:
- compressed-jsonl: content_plugin_2.jsonl.gz
- embeddings-file: content_plugin_2_embeddings.pt
- input-filter:
- - '*2_new.jsonl.gz'
enable-offline-chat: false
search-type:
asymmetric:
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9de3a853..cff2a7f3 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -48,14 +48,3 @@ def test_cli_config_from_file():
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
- assert len(actual_args.config.content_type.plugins.keys()) == 2
- assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
- Path("content_plugin_1_new.jsonl.gz")
- ]
- assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
- assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
- "content_plugin_1.jsonl.gz"
- )
- assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
- "content_plugin_2_embeddings.pt"
- )
diff --git a/tests/test_client.py b/tests/test_client.py
index a5f14882..f63b968c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -2,22 +2,21 @@
from io import BytesIO
from PIL import Image
from urllib.parse import quote
-
+import pytest
# External Packages
from fastapi.testclient import TestClient
-import pytest
+from fastapi import FastAPI
# Internal Packages
-from app.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
-from khoj.search_filter.word_filter import WordFilter
-from khoj.search_filter.file_filter import FileFilter
+from database.models import KhojUser
+from database.adapters import EmbeddingsAdapters
# Test
@@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
# ----------------------------------------------------------------------------------------------------
def test_search_with_valid_content_type(client):
- for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
+ for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}")
# Assert
@@ -75,7 +74,7 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_content_type(client):
- for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
+ for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
@@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
@pytest.mark.skip(reason="Flaky test on parallel test runs")
-def test_get_configured_types_via_api(client):
+def test_get_configured_types_via_api(client, sample_org_data):
# Act
- response = client.get(f"/api/config/types")
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
+
+ enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert
- assert response.status_code == 200
- assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
+ assert list(enabled_types) == ["org"]
# ----------------------------------------------------------------------------------------------------
-def test_get_configured_types_with_only_plugin_content_config(content_config):
+@pytest.mark.django_db(transaction=True)
+def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
# Arrange
- config.content_type = ContentConfig()
- config.content_type.plugins = content_config.plugins
- state.SearchType = configure_search_types(config)
-
- configure_routes(app)
- client = TestClient(app)
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
- assert response.json() == ["all", "plugin1"]
+ assert response.json() == ["all", "org", "markdown", "image"]
# ----------------------------------------------------------------------------------------------------
-def test_get_configured_types_with_no_plugin_content_config(content_config):
+@pytest.mark.django_db(transaction=True)
+def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
- config.content_type = content_config
- config.content_type.plugins = None
state.SearchType = configure_search_types(config)
+ original_config = state.config.content_type
+ state.config.content_type = None
- configure_routes(app)
- client = TestClient(app)
-
- # Act
- response = client.get(f"/api/config/types")
-
- # Assert
- assert response.status_code == 200
- assert "plugin1" not in response.json()
-
-
-# ----------------------------------------------------------------------------------------------------
-def test_get_configured_types_with_no_content_config():
- # Arrange
- config.content_type = ContentConfig()
- state.SearchType = configure_search_types(config)
-
- configure_routes(app)
- client = TestClient(app)
+ configure_routes(fastapi_app)
+ client = TestClient(fastapi_app)
# Act
response = client.get(f"/api/config/types")
@@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
assert response.status_code == 200
assert response.json() == ["all"]
+ # Restore
+ state.config.content_type = original_config
+
# ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
@@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
-def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
+@pytest.mark.django_db(transaction=True)
+def test_notes_search(client, search_config: SearchConfig, sample_org_data):
# Arrange
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- content_index.org = text_search.setup(
- OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
- )
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote("How to git install application?")
# Act
@@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
- filters = [WordFilter(), FileFilter()]
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- content_index.org = text_search.setup(
+ text_search.setup(
OrgToJsonl,
sample_org_data,
- content_config.org,
- search_models.text_search.bi_encoder,
regenerate=False,
- filters=filters,
)
user_query = quote('+"Emacs" file:"*.org"')
@@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
-def test_notes_search_with_include_filter(
- client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
-):
+@pytest.mark.django_db(transaction=True)
+def test_notes_search_with_include_filter(client, sample_org_data):
# Arrange
- filters = [WordFilter()]
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- content_index.org = text_search.setup(
- OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
- )
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote('How to git install application? +"Emacs"')
# Act
@@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
# ----------------------------------------------------------------------------------------------------
-def test_notes_search_with_exclude_filter(
- client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
-):
+@pytest.mark.django_db(transaction=True)
+def test_notes_search_with_exclude_filter(client, sample_org_data):
# Arrange
- filters = [WordFilter()]
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- content_index.org = text_search.setup(
+ text_search.setup(
OrgToJsonl,
sample_org_data,
- content_config.org,
- search_models.text_search.bi_encoder,
regenerate=False,
- filters=filters,
)
user_query = quote('How to git install application? -"clone"')
@@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
assert "clone" not in search_result
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
+ # Arrange
+ text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
+ user_query = quote("How to git install application?")
+
+ # Act
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org")
+
+ # Assert
+ assert response.status_code == 200
+ # assert actual response has no data as the default_user is different from the user making the query (anonymous)
+ assert len(response.json()) == 0
+
+
def get_sample_files_data():
return {
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py
index d0d05bc5..f1f26d28 100644
--- a/tests/test_date_filter.py
+++ b/tests/test_date_filter.py
@@ -1,53 +1,12 @@
# Standard Packages
import re
from datetime import datetime
-from math import inf
# External Packages
import pytest
# Internal Packages
from khoj.search_filter.date_filter import DateFilter
-from khoj.utils.rawconfig import Entry
-
-
-@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
-def test_date_filter():
- entries = [
- Entry(compiled="Entry with no date", raw="Entry with no date"),
- Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
- Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
- ]
-
- q_with_no_date_filter = "head tail"
- ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
- assert ret_query == "head tail"
- assert entry_indices == {0, 1, 2}
-
- q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
- ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
- assert ret_query == "head tail"
- assert entry_indices == set()
-
- query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
- ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
- assert ret_query == "head tail"
- assert entry_indices == {2}
-
- query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
- ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
- assert ret_query == "head tail"
- assert entry_indices == {1}
-
- query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
- ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
- assert ret_query == "head tail"
- assert entry_indices == {2}
-
- query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
- ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
- assert ret_query == "head tail"
- assert entry_indices == {1, 2}
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@@ -56,8 +15,8 @@ def test_extract_date_range():
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
- assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
- assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
+ assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
+ assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py
index ed632d32..f5a903f8 100644
--- a/tests/test_file_filter.py
+++ b/tests/test_file_filter.py
@@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
- assert ret_query == "head tail"
- assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {}
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
- entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
@@ -108,7 +84,7 @@ def test_get_file_filter_terms():
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
# Assert
- assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
+ assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
def arrange_content():
diff --git a/tests/test_jsonl_to_jsonl.py b/tests/test_jsonl_to_jsonl.py
deleted file mode 100644
index b52b5fc9..00000000
--- a/tests/test_jsonl_to_jsonl.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Internal Packages
-from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
-from khoj.utils.rawconfig import Entry
-
-
-def test_process_entries_from_single_input_jsonl(tmp_path):
- "Convert multiple jsonl entries from single file to entries."
- # Arrange
- input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
-{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
-"""
- input_jsonl_file = create_file(tmp_path, input_jsonl)
-
- # Act
- # Process Each Entry from All Notes Files
- input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
- entries = list(map(Entry.from_dict, input_jsons))
- output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
-
- # Assert
- assert len(entries) == 2
- assert output_jsonl == input_jsonl
-
-
-def test_process_entries_from_multiple_input_jsonls(tmp_path):
- "Convert multiple jsonl entries from single file to entries."
- # Arrange
- input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
- input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
- input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
- input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
-
- # Act
- # Process Each Entry from All Notes Files
- input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
- entries = list(map(Entry.from_dict, input_jsons))
- output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
-
- # Assert
- assert len(entries) == 2
- assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
-
-
-def test_get_jsonl_files(tmp_path):
- "Ensure JSONL files specified via input-filter, input-files extracted"
- # Arrange
- # Include via input-filter globs
- group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
- group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
- group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
- group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
- # Include via input-file field
- file1 = create_file(tmp_path, filename="notes.jsonl")
- # Not included by any filter
- create_file(tmp_path, filename="not-included-jsonl.jsonl")
- create_file(tmp_path, filename="not-included-text.txt")
-
- expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
-
- # Setup input-files, input-filters
- input_files = [tmp_path / "notes.jsonl"]
- input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
-
- # Act
- extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
-
- # Assert
- assert len(extracted_org_files) == 5
- assert extracted_org_files == expected_files
-
-
-# Helper Functions
-def create_file(tmp_path, entry=None, filename="test.jsonl"):
- jsonl_file = tmp_path / filename
- jsonl_file.touch()
- if entry:
- jsonl_file.write_text(entry)
- return jsonl_file
diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py
index abf20d09..d47c212e 100644
--- a/tests/test_org_to_jsonl.py
+++ b/tests/test_org_to_jsonl.py
@@ -4,7 +4,7 @@ import os
# Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
-from khoj.processor.text_to_jsonl import TextToJsonl
+from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files
@@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
- TextToJsonl.split_entries_by_max_tokens(
+ TextEmbeddings.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
)
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act
# Split entry by max words and drop words larger than max word length
- processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
+ processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit
diff --git a/tests/test_plaintext_to_jsonl.py b/tests/test_plaintext_to_jsonl.py
index a6da30e1..56c68e38 100644
--- a/tests/test_plaintext_to_jsonl.py
+++ b/tests/test_plaintext_to_jsonl.py
@@ -7,6 +7,7 @@ from pathlib import Path
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
+from database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path):
@@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
assert set(extracted_plaintext_files.keys()) == set(expected_files)
-def test_parse_html_plaintext_file(content_config):
+def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
"Ensure HTML files are parsed correctly"
# Arrange
# Setup input-files, input-filters
- extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
+ config = LocalPlaintextConfig.objects.filter(user=default_user).first()
+ extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 179718fa..af47ffe5 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -3,23 +3,30 @@ import logging
import locale
from pathlib import Path
import os
+import asyncio
# External Packages
import pytest
# Internal Packages
-from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
+from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels
-from khoj.utils.fs_syncer import get_org_files
+from khoj.utils.fs_syncer import get_org_files, collect_files
+from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
+
+logger = logging.getLogger(__name__)
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
# Test
# ----------------------------------------------------------------------------------------------------
-def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
+@pytest.mark.django_db
+def test_text_search_setup_with_missing_file_raises_error(
+ org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
+):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
# ----------------------------------------------------------------------------------------------------
-def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
+@pytest.mark.django_db
+def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
# Arrange
orgfile = tmp_path / "directory.org" / "file.org"
orgfile.parent.mkdir()
with open(orgfile, "w") as f:
f.write("* Heading\n- List item\n")
- org_content_config = TextContentConfig(
- input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
+
+ LocalOrgConfig.objects.create(
+ input_filter=[f"{tmp_path}/**/*"],
+ input_files=None,
+ user=default_user,
)
+ org_files = collect_files(user=default_user)["org"]
+
# Act
# should not raise IsADirectoryError and return orgfile
- assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
+ assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error(
- org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
+ org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
- with pytest.raises(ValueError, match=r"^No valid entries found*"):
- text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+
+ assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+ verify_embeddings(0, default_user)
# ----------------------------------------------------------------------------------------------------
-def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_text_search_setup(content_config, default_user: KhojUser, caplog):
# Arrange
- data = get_org_files(content_config.org)
-
- # Act
- # Regenerate notes embeddings during asymmetric setup
- notes_model = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
- )
+ org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+ data = get_org_files(org_config)
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
- assert len(notes_model.entries) == 10
- assert len(notes_model.corpus_embeddings) == 10
+ assert "Deleting all embeddings for file type org" in caplog.records[1].message
+ assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# ----------------------------------------------------------------------------------------------------
-def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
+@pytest.mark.django_db
+def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
- caplog.set_level(logging.INFO, logger="khoj")
-
- data = get_org_files(content_config.org)
+ org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+ data = get_org_files(org_config)
# Act
# Generate initial notes embeddings during asymmetric setup
- text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
- text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
- assert "Creating index from scratch." in initial_logs
- assert "Creating index from scratch." not in final_logs
+ assert "Deleting all embeddings for file type org" in initial_logs
+ assert "Deleting all embeddings for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
@pytest.mark.anyio
-async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
+# @pytest.mark.asyncio
+async def test_text_search(search_config: SearchConfig):
# Arrange
- data = get_org_files(content_config.org)
-
- search_models.text_search = text_search.initialize_model(search_config.asymmetric)
- content_index.org = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
+ default_user = await KhojUser.objects.acreate(
+ username="test_user", password="test_password", email="test@example.com"
)
+ # Arrange
+ org_config = await LocalOrgConfig.objects.acreate(
+ input_files=None,
+ input_filter=["tests/data/org/*.org"],
+ index_heading_entries=False,
+ user=default_user,
+ )
+ data = get_org_files(org_config)
+
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(
+ None,
+ text_search.setup,
+ OrgToJsonl,
+ data,
+ True,
+ True,
+ default_user,
+ )
+
query = "How to git install application?"
# Act
- hits, entries = await text_search.query(
- query, search_model=search_models.text_search, content=content_index.org, rank_results=True
- )
-
- results = text_search.collate_results(hits, entries, count=1)
+ hits = await text_search.query(default_user, query)
# Assert
+ results = text_search.collate_results(hits)
+ results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
search_result = results[0].entry
assert "git clone" in search_result
# ----------------------------------------------------------------------------------------------------
-def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
- initial_notes_model = text_search.setup(
- OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
- assert len(initial_notes_model.entries) == 2
- assert len(initial_notes_model.corpus_embeddings) == 2
+ record = caplog.records[1]
+ assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
+@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
- org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+ org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
data = {
"readme.org": """
* Khoj
- /Allow natural language search on user content like notes, images using transformer based models/
+/Allow natural language search on user content like notes, images using transformer based models/
- All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
+All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
** Dependencies
- - Python3
- - [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
+- Python3
+- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
** Install
- #+begin_src shell
- git clone https://github.com/khoj-ai/khoj && cd khoj
- conda env create -f environment.yml
- conda activate khoj
- #+end_src"""
+#+begin_src shell
+git clone https://github.com/khoj-ai/khoj && cd khoj
+conda env create -f environment.yml
+conda activate khoj
+#+end_src"""
}
text_search.setup(
OrgToJsonl,
data,
- org_config_with_only_new_file,
- search_models.text_search.bi_encoder,
regenerate=False,
+ user=default_user,
)
max_tokens = 256
@@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
# Act
# reload embeddings, entries, notes model after adding new org-mode file
- initial_notes_model = text_search.setup(
- OrgToJsonl,
- data,
- org_config_with_only_new_file,
- search_models.text_search.bi_encoder,
- regenerate=False,
- full_corpus=False,
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(
+ OrgToJsonl,
+ data,
+ regenerate=False,
+ full_corpus=False,
+ user=default_user,
+ )
+
+ record = caplog.records[1]
# Assert
# verify newly added org-mode entry is split by max tokens
- assert len(initial_notes_model.entries) == 5
- assert len(initial_notes_model.corpus_embeddings) == 5
+ assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
- content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
+ content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
# Arrange
- data = get_org_files(content_config.org)
- initial_notes_model = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
- )
+ org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+ data = get_org_files(org_config)
- assert len(initial_notes_model.entries) == 10
- assert len(initial_notes_model.corpus_embeddings) == 10
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
+
+ assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# append org-mode entry to first org input file in config
- content_config.org.input_files = [f"{new_org_file}"]
+ org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
- data = get_org_files(content_config.org)
+ data = get_org_files(org_config)
# Act
# regenerate notes jsonl, model embeddings and model to include entry from new file
- regenerated_notes_model = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
- assert len(regenerated_notes_model.entries) == 11
- assert len(regenerated_notes_model.corpus_embeddings) == 11
-
- # verify new entry appended to index, without disrupting order or content of existing entries
- error_details = compare_index(initial_notes_model, regenerated_notes_model)
- if error_details:
- pytest.fail(error_details, False)
-
- # Cleanup
- # reset input_files in config to empty list
- content_config.org.input_files = []
+ assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
+ verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
- org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
+ org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Act
# load embeddings, entries, notes model after adding new org-mode file
- initial_index = text_search.setup(
- OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file
- updated_index = text_search.setup(
- OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
- assert len(initial_index.entries) == len(updated_index.entries) == 1
- assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
+ assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+ assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
- # verify the same entry is added even when there are multiple duplicate entries
- error_details = compare_index(initial_index, updated_index)
- if error_details:
- pytest.fail(error_details)
+ verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
-def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+@pytest.mark.django_db
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
- initial_index = text_search.setup(
- OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
@@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# Act
- updated_index = text_search.setup(
- OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
- assert len(initial_index.entries) == len(updated_index.entries) + 1
- assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
+ assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+ assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
- # verify the same entry is added even when there are multiple duplicate entries
- error_details = compare_index(updated_index, initial_index)
- if error_details:
- pytest.fail(error_details)
+ verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
-def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
+@pytest.mark.django_db
+def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
# Arrange
- data = get_org_files(content_config.org)
- initial_notes_model = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
- )
+ org_config = LocalOrgConfig.objects.filter(user=default_user).first()
+ data = get_org_files(org_config)
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
- data = get_org_files(content_config.org)
+ data = get_org_files(org_config)
# Act
# update embeddings, entries with the newly added note
- content_config.org.input_files = [f"{new_org_file}"]
- final_notes_model = text_search.setup(
- OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
- )
+ with caplog.at_level(logging.INFO):
+ text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
- assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
- assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
+ assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
+ assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
- # verify new entry appended to index, without disrupting order or content of existing entries
- error_details = compare_index(initial_notes_model, final_notes_model)
- if error_details:
- pytest.fail(error_details, False)
-
- # Cleanup
- # reset input_files in config to empty list
- content_config.org.input_files = []
+ verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
-def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
+def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
+ # Arrange
+ github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
- github_model = text_search.setup(
- GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
+ text_search.setup(
+ GithubToJsonl,
+ {},
+ regenerate=True,
+ user=default_user,
+ config=github_config,
)
# Assert
- assert len(github_model.entries) > 1
+ embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
+ assert embeddings > 1
-def compare_index(initial_notes_model, final_notes_model):
- mismatched_entries, mismatched_embeddings = [], []
- for index in range(len(initial_notes_model.entries)):
- if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
- mismatched_entries.append(index)
-
- # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
- for index in range(len(initial_notes_model.corpus_embeddings)):
- if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
- mismatched_embeddings.append(index)
-
- error_details = ""
- if mismatched_entries:
- mismatched_entries_str = ",".join(map(str, mismatched_entries))
- error_details += f"Entries at {mismatched_entries_str} not equal\n"
- if mismatched_embeddings:
- mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
- error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
-
- return error_details
+def verify_embeddings(expected_count, user):
+ embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
+ assert embeddings == expected_count
diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py
index 04f45506..2ede35e7 100644
--- a/tests/test_word_filter.py
+++ b/tests/test_word_filter.py
@@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry
-def test_no_word_filter():
- # Arrange
- word_filter = WordFilter()
- entries = arrange_content()
- q_with_no_filter = "head tail"
-
- # Act
- can_filter = word_filter.can_filter(q_with_no_filter)
- ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
-
- # Assert
- assert can_filter == False
- assert ret_query == "head tail"
- assert entry_indices == {0, 1, 2, 3}
-
-
def test_word_exclude_filter():
# Arrange
word_filter = WordFilter()
- entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
- ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {0, 2}
def test_word_include_filter():
# Arrange
word_filter = WordFilter()
- entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
- ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {2, 3}
def test_word_include_and_exclude_filter():
# Arrange
word_filter = WordFilter()
- entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
- ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
- assert ret_query == "head tail"
- assert entry_indices == {2}
def test_get_word_filter_terms():