Make Search Model Configurable on Server (#544)

- Make search model configurable on server
- Update migration script to get search model from `khoj.yml` to Postgres
- Update first run message on Khoj Desktop and Web app landing page
- Other miscellaneous bug fixes
This commit is contained in:
Debanjum 2023-11-16 00:11:58 -08:00 committed by GitHub
commit 208ddddc6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 263 additions and 84 deletions

43
.github/workflows/dockerize_dev.yml vendored Normal file
View file

@ -0,0 +1,43 @@
name: dockerize-dev
on:
pull_request:
paths:
- src/khoj/**
- config/**
- pyproject.toml
- prod.Dockerfile
- .github/workflows/dockerize_dev.yml
workflow_dispatch:
env:
DOCKER_IMAGE_TAG: 'dev'
jobs:
build:
name: Build Production Docker Image, Push to Container Registry
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.PAT }}
- name: 📦 Build and Push Docker Image
uses: docker/build-push-action@v2
with:
context: .
file: prod.Dockerfile
platforms: linux/amd64
push: true
tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
build-args: |
PORT=42110

View file

@ -54,7 +54,7 @@ dependencies = [
"transformers >= 4.28.0",
"torch == 2.0.1",
"uvicorn == 0.17.6",
"aiohttp == 3.8.5",
"aiohttp == 3.8.6",
"langchain >= 0.0.331",
"requests >= 2.26.0",
"bs4 >= 0.0.1",

View file

@ -1,8 +1,8 @@
import math
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
from typing import Optional, Type, List
from datetime import date, datetime
import secrets
from typing import Type, TypeVar, List
from typing import Type, List
from datetime import date, timezone
from django.db import models
@ -11,10 +11,6 @@ from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
@ -31,6 +27,7 @@ from database.models import (
GithubRepoConfig,
Conversation,
ChatModelOptions,
SearchModelConfig,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
@ -41,15 +38,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
ModelType = TypeVar("ModelType", bound=models.Model)
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
instance = await model_class.objects.filter(id=id).afirst()
if not instance:
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
return instance
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
@ -65,9 +53,7 @@ async def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}"
name = name or f"{generate_random_name().title()}"
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
await api_config.asave()
return api_config
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
def get_khoj_tokens(user: KhojUser):
@ -83,13 +69,16 @@ async def delete_khoj_token(user: KhojUser, token: str):
async def get_or_create_user(token: dict) -> KhojUser:
user = await get_user_by_token(token)
if not user:
user = await create_google_user(token)
user = await create_user_by_google_token(token)
return user
async def create_google_user(token: dict) -> KhojUser:
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
async def create_user_by_google_token(token: dict) -> KhojUser:
user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create(
defaults={"username": token.get("email"), "email": token.get("email")}
)
await user.asave()
await GoogleUser.objects.acreate(
sub=token.get("sub"),
azp=token.get("azp"),
@ -220,6 +209,14 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
def get_or_create_search_model():
search_model = SearchModelConfig.objects.filter().first()
if not search_model:
search_model = SearchModelConfig.objects.create()
return search_model
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):

View file

@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
SearchModelConfig,
Subscription,
)
@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModelConfig)
admin.site.register(Subscription)

View file

@ -0,0 +1,32 @@
# Generated by Django 4.2.5 on 2023-11-14 23:25
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0016_alter_subscription_renewal_date"),
]
operations = [
migrations.CreateModel(
name="SearchModel",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(default="default", max_length=200)),
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
(
"cross_encoder",
models.CharField(
blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
),
),
],
options={
"abstract": False,
},
),
]

View file

@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-11-16 01:13
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0017_searchmodel"),
]
operations = [
migrations.CreateModel(
name="SearchModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(default="default", max_length=200)),
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)),
],
options={
"abstract": False,
},
),
migrations.DeleteModel(
name="SearchModel",
),
]

View file

@ -102,6 +102,16 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModelConfig(BaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
name = models.CharField(max_length=200, default="default")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)

View file

@ -328,7 +328,15 @@
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render a setup hint.
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
first_run_message = `Hi 👋🏾, to get started:
<ol>
<li>Generate an API token in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToWebSettings()">Khoj Web settings</a></li>
<li>Paste it into the API Key field in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToSettings()">Khoj Desktop settings</a></li>
</ol>`
.trim()
.replace(/(\r\n|\n|\r)/gm, "");
renderMessage(first_run_message, "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");

View file

@ -396,6 +396,14 @@ app.whenReady().then(() => {
event.reply('update-state', arg);
});
ipcMain.on('navigate', (event, page) => {
win.loadFile(page);
});
ipcMain.on('navigateToWebApp', (event, page) => {
shell.openExternal(`${store.get('hostURL')}/${page}`);
});
ipcMain.handle('getFiles', getFiles);
ipcMain.handle('getFolders', getFolders);

View file

@ -10,14 +10,14 @@
"main": "main.js",
"private": false,
"devDependencies": {
"electron": "25.8.1"
"electron": "25.8.4"
},
"scripts": {
"start": "yarn electron ."
},
"dependencies": {
"@todesktop/runtime": "^1.3.0",
"axios": "^1.5.0",
"axios": "^1.6.0",
"cron": "^2.4.3",
"electron-store": "^8.1.0",
"fs": "^0.0.1-security"

View file

@ -57,3 +57,8 @@ contextBridge.exposeInMainWorld('tokenAPI', {
contextBridge.exposeInMainWorld('appInfoAPI', {
getInfo: (callback) => ipcRenderer.on('appInfo', callback)
})
contextBridge.exposeInMainWorld('navigateAPI', {
navigateToSettings: () => ipcRenderer.send('navigate', 'config.html'),
navigateToWebSettings: () => ipcRenderer.send('navigateToWebApp', 'config'),
})

View file

@ -163,10 +163,10 @@ atomically@^1.7.0:
resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe"
integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w==
axios@^1.5.0:
version "1.5.0"
resolved "https://registry.yarnpkg.com/axios/-/axios-1.5.0.tgz#f02e4af823e2e46a9768cfc74691fdd0517ea267"
integrity sha512-D4DdjDo5CY50Qms0qGQTTw6Q44jl7zRwY7bthds06pUGfChBCTcQs+N743eFWGEd6pRTMd6A+I87aWyFV5wiZQ==
axios@^1.6.0:
version "1.6.2"
resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.2.tgz#de67d42c755b571d3e698df1b6504cde9b0ee9f2"
integrity sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==
dependencies:
follow-redirects "^1.15.0"
form-data "^4.0.0"
@ -379,10 +379,10 @@ electron-updater@^4.6.1:
lodash.isequal "^4.5.0"
semver "^7.3.5"
electron@25.8.1:
version "25.8.1"
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.1.tgz#092fab5a833db4d9240d4d6f36218cf7ca954f86"
integrity sha512-GtcP1nMrROZfFg0+mhyj1hamrHvukfF6of2B/pcWxmWkd5FVY1NJib0tlhiorFZRzQN5Z+APLPr7aMolt7i2AQ==
electron@25.8.4:
version "25.8.4"
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.4.tgz#b50877aac7d96323920437baf309ad86382cb455"
integrity sha512-hUYS3RGdaa6E1UWnzeGnsdsBYOggwMMg4WGxNGvAoWtmRrr6J1BsjFW/yRq4WsJHJce2HdzQXtz4OGXV6yUCLg==
dependencies:
"@electron/get" "^2.0.0"
"@types/node" "^18.11.18"

View file

@ -3,7 +3,6 @@ import logging
import json
from enum import Enum
from typing import Optional
from fastapi import Request
import requests
import os
@ -21,15 +20,16 @@ from starlette.authentication import (
)
# Internal Packages
from database.models import KhojUser, Subscription
from database.adapters import get_all_users, get_or_create_search_model
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
)
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser, Subscription
from database.adapters import get_all_users
logger = logging.getLogger(__name__)
@ -113,14 +113,13 @@ def configure_server(
# Initialize Search Models from Config and initialize content
try:
state.config_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init, user)
except Exception as e:
raise e
finally:
state.config_lock.release()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
@ -192,7 +191,7 @@ def update_search_index():
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
def configure_search_types(config: FullConfig):
def configure_search_types():
# Extract core search types
core_search_types = {e.name: e.value for e in SearchType}

View file

@ -327,7 +327,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render a setup hint.
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
renderMessage("Hi 👋🏾, to start chatting add available chat models options via <a class='inline-chat-link' href='/server/admin'>the Django Admin panel</a> on the Server", "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");

View file

@ -30,7 +30,7 @@ search-type:
encoder: sentence-transformers/all-MiniLM-L6-v2
encoder-type: null
model-directory: ~/.khoj/search/symmetric
version: 0.12.4
version: 0.14.0
The new version will looks like this:
@ -53,11 +53,7 @@ search-type:
asymmetric:
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
image:
encoder: sentence-transformers/clip-ViT-B-32
encoder-type: null
model-directory: /Users/si/.khoj/search/image
version: 0.12.4
version: 0.15.0
"""
import logging
@ -68,6 +64,7 @@ from database.models import (
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
ChatModelOptions,
SearchModelConfig,
)
logger = logging.getLogger(__name__)
@ -87,6 +84,19 @@ def migrate_server_pg(args):
if raw_config is None:
return args
if "search-type" in raw_config and raw_config["search-type"]:
if "asymmetric" in raw_config["search-type"]:
# Delete all existing search models
SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
# Create new search model from existing Khoj YAML config
asymmetric_search = raw_config["search-type"]["asymmetric"]
SearchModelConfig.objects.create(
name="default",
model_type=SearchModelConfig.ModelType.TEXT,
bi_encoder=asymmetric_search.get("encoder"),
cross_encoder=asymmetric_search.get("cross-encoder"),
)
if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
processor_conversation = raw_config["processor"]["conversation"]

View file

@ -1,16 +1,17 @@
from typing import List
from sentence_transformers import SentenceTransformer, CrossEncoder
from torch import nn
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
def __init__(self, model_name: str = "thenlper/gte-small"):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small"
self.model_name = model_name
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
@ -21,11 +22,11 @@ class EmbeddingsModel:
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
cross_inp = [[query, hit.additional[key]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
return cross_scores

View file

@ -6,12 +6,12 @@ import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages
from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters
@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC):
def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel()
self.embeddings_model = state.embeddings_model
self.config = config
self.date_filter = DateFilter()

View file

@ -376,7 +376,7 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
max_distance = max_distance if max_distance is not None else math.inf
max_distance = max_distance or math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
@ -581,7 +581,7 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.15,
d: Optional[float] = 0.18,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),

View file

@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests
# Internal Packages
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
from database.models import KhojApiUser
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state
@ -51,12 +52,16 @@ async def login(request: Request):
@auth_router.post("/token")
@requires(["authenticated"], redirect="login_page")
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
async def generate_token(request: Request, token_name: Optional[str] = None):
"Generate API token for given user"
if token_name:
return await create_khoj_token(user=request.user.object, name=token_name)
token = await create_khoj_token(user=request.user.object, name=token_name)
else:
return await create_khoj_token(user=request.user.object)
token = await create_khoj_token(user=request.user.object)
return {
"token": token.token,
"name": token.name,
}
@auth_router.get("/token")

View file

@ -2,7 +2,7 @@
import logging
import math
from pathlib import Path
from typing import List, Tuple, Type, Union, Dict
from typing import List, Tuple, Type, Union
# External Packages
import torch

View file

@ -5,21 +5,20 @@ from typing import List, Dict
from collections import defaultdict
# External Packages
import torch
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State
config = FullConfig()
search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None
@ -28,7 +27,6 @@ host: str = None
port: int = None
cli_args: List[str] = None
query_cache: Dict[str, LRU] = defaultdict(LRU)
config_lock = threading.Lock()
chat_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []

View file

@ -8,11 +8,13 @@ from fastapi import FastAPI
import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
@ -222,7 +227,7 @@ def md_content_config():
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
@ -256,7 +261,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
@ -291,7 +296,9 @@ def client(
):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image)
@ -323,7 +330,7 @@ def client(
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,

View file

@ -7,6 +7,7 @@ from database.models import (
ChatModelOptions,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
SearchModelConfig,
UserConversationConfig,
Conversation,
Subscription,
@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = SearchModelConfig
name = "default"
model_type = "text"
bi_encoder = "thenlper/gte-small"
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription

View file

@ -173,7 +173,6 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client, sample_org_data):
# Act
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
@ -203,10 +202,10 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
@pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
state.anonymous_mode = True
if state.config and state.config.content_type:
state.config.content_type = None
state.search_models = configure_search_types()
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
@ -218,9 +217,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
assert response.status_code == 200
assert response.json() == ["all"]
# Restore
state.config.content_type = original_config
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@ -259,13 +255,30 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
# Assert
assert response.status_code == 200
# assert actual_data contains "Khoj via Emacs" entry
assert len(response.json()) == 1, "Expected only 1 result"
search_result = response.json()[0]["entry"]
assert "git clone https://github.com/khoj-ai/khoj" in search_result
assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to find my goat?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
# Assert
assert response.status_code == 200
assert response.json() == [], "Expected no results"
# ----------------------------------------------------------------------------------------------------