mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Make Search Model Configurable on Server (#544)
- Make search model configurable on server - Update migration script to get search model from `khoj.yml` to Postgres - Update first run message on Khoj Desktop and Web app landing page - Other miscellaneous bug fixes
This commit is contained in:
commit
208ddddc6a
24 changed files with 263 additions and 84 deletions
43
.github/workflows/dockerize_dev.yml
vendored
Normal file
43
.github/workflows/dockerize_dev.yml
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
# Builds the Docker image from prod.Dockerfile and pushes it to the
# GitHub Container Registry under the tag set in DOCKER_IMAGE_TAG.
name: dockerize-dev

on:
  # Run for pull requests that touch the server code, config, packaging,
  # the Dockerfile or this workflow itself
  pull_request:
    paths:
      - src/khoj/**
      - config/**
      - pyproject.toml
      - prod.Dockerfile
      - .github/workflows/dockerize_dev.yml
  # Allow the workflow to be started manually
  workflow_dispatch:

env:
  # Tag applied to the pushed image
  DOCKER_IMAGE_TAG: 'dev'

jobs:
  build:
    name: Build Production Docker Image, Push to Container Registry
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      # Authenticate against ghcr.io using the PAT repository secret
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v2
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.PAT }}

      - name: 📦 Build and Push Docker Image
        uses: docker/build-push-action@v2
        with:
          context: .
          file: prod.Dockerfile
          platforms: linux/amd64
          push: true
          tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
          # Passed to the Dockerfile as a build argument
          build-args: |
            PORT=42110
|
|
@ -54,7 +54,7 @@ dependencies = [
|
|||
"transformers >= 4.28.0",
|
||||
"torch == 2.0.1",
|
||||
"uvicorn == 0.17.6",
|
||||
"aiohttp == 3.8.5",
|
||||
"aiohttp == 3.8.6",
|
||||
"langchain >= 0.0.331",
|
||||
"requests >= 2.26.0",
|
||||
"bs4 >= 0.0.1",
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import math
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Optional, Type, List
|
||||
from datetime import date, datetime
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from typing import Type, List
|
||||
from datetime import date, timezone
|
||||
|
||||
from django.db import models
|
||||
|
@ -11,10 +11,6 @@ from pgvector.django import CosineDistance
|
|||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
from pgvector.django import CosineDistance
|
||||
from django.db.models.manager import BaseManager
|
||||
from django.db.models import Q
|
||||
from torch import Tensor
|
||||
|
||||
# Import sync_to_async from Django Channels
|
||||
from asgiref.sync import sync_to_async
|
||||
|
@ -31,6 +27,7 @@ from database.models import (
|
|||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ChatModelOptions,
|
||||
SearchModelConfig,
|
||||
Subscription,
|
||||
UserConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
|
@ -41,15 +38,6 @@ from khoj.search_filter.word_filter import WordFilter
|
|||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||
|
||||
|
||||
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
|
||||
instance = await model_class.objects.filter(id=id).afirst()
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
|
||||
return instance
|
||||
|
||||
|
||||
async def set_notion_config(token: str, user: KhojUser):
|
||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||
|
@ -65,9 +53,7 @@ async def create_khoj_token(user: KhojUser, name=None):
|
|||
"Create Khoj API key for user"
|
||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||
name = name or f"{generate_random_name().title()}"
|
||||
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||
await api_config.asave()
|
||||
return api_config
|
||||
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||
|
||||
|
||||
def get_khoj_tokens(user: KhojUser):
|
||||
|
@ -83,13 +69,16 @@ async def delete_khoj_token(user: KhojUser, token: str):
|
|||
async def get_or_create_user(token: dict) -> KhojUser:
|
||||
user = await get_user_by_token(token)
|
||||
if not user:
|
||||
user = await create_google_user(token)
|
||||
user = await create_user_by_google_token(token)
|
||||
return user
|
||||
|
||||
|
||||
async def create_google_user(token: dict) -> KhojUser:
|
||||
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
|
||||
async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||
user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create(
|
||||
defaults={"username": token.get("email"), "email": token.get("email")}
|
||||
)
|
||||
await user.asave()
|
||||
|
||||
await GoogleUser.objects.acreate(
|
||||
sub=token.get("sub"),
|
||||
azp=token.get("azp"),
|
||||
|
@ -220,6 +209,14 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||
return config
|
||||
|
||||
|
||||
def get_or_create_search_model():
    """Return the first stored search model config, creating one if none exists.

    A newly created row is built with no explicit arguments, so it takes
    whatever defaults SearchModelConfig declares for its fields.
    """
    existing_model = SearchModelConfig.objects.filter().first()
    if existing_model:
        return existing_model

    # Nothing configured yet: persist a config made purely from field defaults
    return SearchModelConfig.objects.create()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
|
|
|
@ -8,6 +8,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
SearchModelConfig,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
|
@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
|
|||
admin.site.register(ChatModelOptions)
|
||||
admin.site.register(OpenAIProcessorConversationConfig)
|
||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||
admin.site.register(SearchModelConfig)
|
||||
admin.site.register(Subscription)
|
||||
|
|
32
src/database/migrations/0017_searchmodel.py
Normal file
32
src/database/migrations/0017_searchmodel.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-14 23:25
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Create the SearchModel table, which stores the search model names."""

    dependencies = [
        ("database", "0016_alter_subscription_renewal_date"),
    ]

    operations = [
        migrations.CreateModel(
            name="SearchModel",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("name", models.CharField(default="default", max_length=200)),
                (
                    "model_type",
                    models.CharField(choices=[("text", "Text")], default="text", max_length=200),
                ),
                ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
                # The cross encoder is optional in this revision (blank and null allowed)
                (
                    "cross_encoder",
                    models.CharField(
                        blank=True,
                        default="cross-encoder/ms-marco-MiniLM-L-6-v2",
                        max_length=200,
                        null=True,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]
|
|
@ -0,0 +1,30 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-16 01:13
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Create the SearchModelConfig table and drop the SearchModel table.

    No data operation is included, so rows in SearchModel are not copied
    into SearchModelConfig before the old table is deleted.
    """

    dependencies = [
        ("database", "0017_searchmodel"),
    ]

    operations = [
        migrations.CreateModel(
            name="SearchModelConfig",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("name", models.CharField(default="default", max_length=200)),
                (
                    "model_type",
                    models.CharField(choices=[("text", "Text")], default="text", max_length=200),
                ),
                ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
                # The cross encoder is a required (non-null) column in this table
                (
                    "cross_encoder",
                    models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.DeleteModel(
            name="SearchModel",
        ),
    ]
|
|
@ -102,6 +102,16 @@ class LocalPlaintextConfig(BaseModel):
|
|||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class SearchModelConfig(BaseModel):
    """Database-stored search model settings: a named bi-encoder / cross-encoder pair."""

    class ModelType(models.TextChoices):
        # The kinds of search model a config can describe; only text is defined
        TEXT = "text"

    # Label for this configuration
    name = models.CharField(max_length=200, default="default")
    # Which kind of search this configuration applies to
    model_type = models.CharField(
        max_length=200,
        choices=ModelType.choices,
        default=ModelType.TEXT,
    )
    # Name of the bi-encoder (embeddings) model
    bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
    # Name of the cross-encoder model
    cross_encoder = models.CharField(
        max_length=200,
        default="cross-encoder/ms-marco-MiniLM-L-6-v2",
    )
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
|
||||
|
|
|
@ -328,7 +328,15 @@
|
|||
.then(data => {
|
||||
if (data.detail) {
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
||||
first_run_message = `Hi 👋🏾, to get started:
|
||||
<ol>
|
||||
<li>Generate an API token in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToWebSettings()">Khoj Web settings</a></li>
|
||||
<li>Paste it into the API Key field in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToSettings()">Khoj Desktop settings</a></li>
|
||||
</ol>`
|
||||
.trim()
|
||||
.replace(/(\r\n|\n|\r)/gm, "");
|
||||
|
||||
renderMessage(first_run_message, "khoj");
|
||||
|
||||
// Disable chat input field and update placeholder text
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
|
|
@ -396,6 +396,14 @@ app.whenReady().then(() => {
|
|||
event.reply('update-state', arg);
|
||||
});
|
||||
|
||||
ipcMain.on('navigate', (event, page) => {
|
||||
win.loadFile(page);
|
||||
});
|
||||
|
||||
ipcMain.on('navigateToWebApp', (event, page) => {
|
||||
shell.openExternal(`${store.get('hostURL')}/${page}`);
|
||||
});
|
||||
|
||||
ipcMain.handle('getFiles', getFiles);
|
||||
ipcMain.handle('getFolders', getFolders);
|
||||
|
||||
|
|
|
@ -10,14 +10,14 @@
|
|||
"main": "main.js",
|
||||
"private": false,
|
||||
"devDependencies": {
|
||||
"electron": "25.8.1"
|
||||
"electron": "25.8.4"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "yarn electron ."
|
||||
},
|
||||
"dependencies": {
|
||||
"@todesktop/runtime": "^1.3.0",
|
||||
"axios": "^1.5.0",
|
||||
"axios": "^1.6.0",
|
||||
"cron": "^2.4.3",
|
||||
"electron-store": "^8.1.0",
|
||||
"fs": "^0.0.1-security"
|
||||
|
|
|
@ -57,3 +57,8 @@ contextBridge.exposeInMainWorld('tokenAPI', {
|
|||
contextBridge.exposeInMainWorld('appInfoAPI', {
|
||||
getInfo: (callback) => ipcRenderer.on('appInfo', callback)
|
||||
})
|
||||
|
||||
// Expose navigation helpers to the renderer as window.navigateAPI.
// Each helper only sends an IPC message; the main process does the navigation.
contextBridge.exposeInMainWorld('navigateAPI', {
    // Sends 'config.html' on the 'navigate' channel
    navigateToSettings: () => {
        ipcRenderer.send('navigate', 'config.html');
    },
    // Sends 'config' on the 'navigateToWebApp' channel
    navigateToWebSettings: () => {
        ipcRenderer.send('navigateToWebApp', 'config');
    },
})
|
||||
|
|
|
@ -163,10 +163,10 @@ atomically@^1.7.0:
|
|||
resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe"
|
||||
integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w==
|
||||
|
||||
axios@^1.5.0:
|
||||
version "1.5.0"
|
||||
resolved "https://registry.yarnpkg.com/axios/-/axios-1.5.0.tgz#f02e4af823e2e46a9768cfc74691fdd0517ea267"
|
||||
integrity sha512-D4DdjDo5CY50Qms0qGQTTw6Q44jl7zRwY7bthds06pUGfChBCTcQs+N743eFWGEd6pRTMd6A+I87aWyFV5wiZQ==
|
||||
axios@^1.6.0:
|
||||
version "1.6.2"
|
||||
resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.2.tgz#de67d42c755b571d3e698df1b6504cde9b0ee9f2"
|
||||
integrity sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==
|
||||
dependencies:
|
||||
follow-redirects "^1.15.0"
|
||||
form-data "^4.0.0"
|
||||
|
@ -379,10 +379,10 @@ electron-updater@^4.6.1:
|
|||
lodash.isequal "^4.5.0"
|
||||
semver "^7.3.5"
|
||||
|
||||
electron@25.8.1:
|
||||
version "25.8.1"
|
||||
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.1.tgz#092fab5a833db4d9240d4d6f36218cf7ca954f86"
|
||||
integrity sha512-GtcP1nMrROZfFg0+mhyj1hamrHvukfF6of2B/pcWxmWkd5FVY1NJib0tlhiorFZRzQN5Z+APLPr7aMolt7i2AQ==
|
||||
electron@25.8.4:
|
||||
version "25.8.4"
|
||||
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.4.tgz#b50877aac7d96323920437baf309ad86382cb455"
|
||||
integrity sha512-hUYS3RGdaa6E1UWnzeGnsdsBYOggwMMg4WGxNGvAoWtmRrr6J1BsjFW/yRq4WsJHJce2HdzQXtz4OGXV6yUCLg==
|
||||
dependencies:
|
||||
"@electron/get" "^2.0.0"
|
||||
"@types/node" "^18.11.18"
|
||||
|
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from fastapi import Request
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
@ -21,15 +20,16 @@ from starlette.authentication import (
|
|||
)
|
||||
|
||||
# Internal Packages
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users, get_or_create_search_model
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import (
|
||||
SearchType,
|
||||
)
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -113,14 +113,13 @@ def configure_server(
|
|||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
||||
state.SearchType = configure_search_types()
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
initialize_content(regenerate, search_type, init, user)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
|
||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||
|
@ -192,7 +191,7 @@ def update_search_index():
|
|||
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||
|
||||
|
||||
def configure_search_types(config: FullConfig):
|
||||
def configure_search_types():
|
||||
# Extract core search types
|
||||
core_search_types = {e.name: e.value for e in SearchType}
|
||||
|
||||
|
|
|
@ -327,7 +327,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
.then(data => {
|
||||
if (data.detail) {
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
||||
renderMessage("Hi 👋🏾, to start chatting add available chat models options via <a class='inline-chat-link' href='/server/admin'>the Django Admin panel</a> on the Server", "khoj");
|
||||
|
||||
// Disable chat input field and update placeholder text
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
|
|
@ -30,7 +30,7 @@ search-type:
|
|||
encoder: sentence-transformers/all-MiniLM-L6-v2
|
||||
encoder-type: null
|
||||
model-directory: ~/.khoj/search/symmetric
|
||||
version: 0.12.4
|
||||
version: 0.14.0
|
||||
|
||||
|
||||
The new version will look like this:
|
||||
|
@ -53,11 +53,7 @@ search-type:
|
|||
asymmetric:
|
||||
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||
encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
|
||||
image:
|
||||
encoder: sentence-transformers/clip-ViT-B-32
|
||||
encoder-type: null
|
||||
model-directory: /Users/si/.khoj/search/image
|
||||
version: 0.12.4
|
||||
version: 0.15.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
@ -68,6 +64,7 @@ from database.models import (
|
|||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
ChatModelOptions,
|
||||
SearchModelConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -87,6 +84,19 @@ def migrate_server_pg(args):
|
|||
if raw_config is None:
|
||||
return args
|
||||
|
||||
if "search-type" in raw_config and raw_config["search-type"]:
|
||||
if "asymmetric" in raw_config["search-type"]:
|
||||
# Delete all existing search models
|
||||
SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
|
||||
# Create new search model from existing Khoj YAML config
|
||||
asymmetric_search = raw_config["search-type"]["asymmetric"]
|
||||
SearchModelConfig.objects.create(
|
||||
name="default",
|
||||
model_type=SearchModelConfig.ModelType.TEXT,
|
||||
bi_encoder=asymmetric_search.get("encoder"),
|
||||
cross_encoder=asymmetric_search.get("cross-encoder"),
|
||||
)
|
||||
|
||||
if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
|
||||
processor_conversation = raw_config["processor"]["conversation"]
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from typing import List
|
||||
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
|
||||
class EmbeddingsModel:
|
||||
def __init__(self):
|
||||
def __init__(self, model_name: str = "thenlper/gte-small"):
|
||||
self.encode_kwargs = {"normalize_embeddings": True}
|
||||
self.model_kwargs = {"device": get_device()}
|
||||
self.model_name = "thenlper/gte-small"
|
||||
self.model_name = model_name
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def embed_query(self, query):
|
||||
|
@ -21,11 +22,11 @@ class EmbeddingsModel:
|
|||
|
||||
|
||||
class CrossEncoderModel:
|
||||
def __init__(self):
|
||||
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
||||
self.model_name = model_name
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||
|
||||
def predict(self, query, hits: List[SearchResponse]):
|
||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
|
||||
return cross_scores
|
||||
|
|
|
@ -6,12 +6,12 @@ import logging
|
|||
import uuid
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, List, Tuple, Set, Any
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||
from database.adapters import EntryAdapters
|
||||
|
@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class TextToEntries(ABC):
|
||||
def __init__(self, config: Any = None):
|
||||
self.embeddings_model = EmbeddingsModel()
|
||||
self.embeddings_model = state.embeddings_model
|
||||
self.config = config
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
|
|
|
@ -376,7 +376,7 @@ async def search(
|
|||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n or 5
|
||||
max_distance = max_distance if max_distance is not None else math.inf
|
||||
max_distance = max_distance or math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
|
@ -581,7 +581,7 @@ async def chat(
|
|||
request: Request,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.15,
|
||||
d: Optional[float] = 0.18,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
|
|
|
@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests
|
|||
|
||||
# Internal Packages
|
||||
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
||||
from database.models import KhojApiUser
|
||||
from khoj.routers.helpers import update_telemetry_state
|
||||
from khoj.utils import state
|
||||
|
||||
|
@ -51,12 +52,16 @@ async def login(request: Request):
|
|||
|
||||
@auth_router.post("/token")
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
|
||||
async def generate_token(request: Request, token_name: Optional[str] = None):
|
||||
"Generate API token for given user"
|
||||
if token_name:
|
||||
return await create_khoj_token(user=request.user.object, name=token_name)
|
||||
token = await create_khoj_token(user=request.user.object, name=token_name)
|
||||
else:
|
||||
return await create_khoj_token(user=request.user.object)
|
||||
token = await create_khoj_token(user=request.user.object)
|
||||
return {
|
||||
"token": token.token,
|
||||
"name": token.name,
|
||||
}
|
||||
|
||||
|
||||
@auth_router.get("/token")
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Type, Union, Dict
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
|
|
@ -5,21 +5,20 @@ from typing import List, Dict
|
|||
from collections import defaultdict
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import LRU, get_device
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
search_models = SearchModels()
|
||||
embeddings_model = EmbeddingsModel()
|
||||
cross_encoder_model = CrossEncoderModel()
|
||||
embeddings_model: EmbeddingsModel = None
|
||||
cross_encoder_model: CrossEncoderModel = None
|
||||
content_index = ContentIndex()
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
config_file: Path = None
|
||||
|
@ -28,7 +27,6 @@ host: str = None
|
|||
port: int = None
|
||||
cli_args: List[str] = None
|
||||
query_cache: Dict[str, LRU] = defaultdict(LRU)
|
||||
config_lock = threading.Lock()
|
||||
chat_lock = threading.Lock()
|
||||
SearchType = utils_config.SearchType
|
||||
telemetry: List[Dict[str, str]] = []
|
||||
|
|
|
@ -8,11 +8,13 @@ from fastapi import FastAPI
|
|||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
|
@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_config() -> SearchConfig:
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
search_config = SearchConfig()
|
||||
|
@ -222,7 +227,7 @@ def md_content_config():
|
|||
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
|
@ -256,7 +261,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
|||
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
# Initialize Processor from Config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
|
@ -291,7 +296,9 @@ def client(
|
|||
):
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
@ -323,7 +330,7 @@ def client(
|
|||
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
|
|
|
@ -7,6 +7,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
SearchModelConfig,
|
||||
UserConversationConfig,
|
||||
Conversation,
|
||||
Subscription,
|
||||
|
@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||
user = factory.SubFactory(UserFactory)
|
||||
|
||||
|
||||
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Test factory that builds SearchModelConfig rows with fixed model names."""

    class Meta:
        model = SearchModelConfig

    # Fixed values used for every row this factory builds
    name = "default"
    model_type = "text"
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
|
||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Subscription
|
||||
|
|
|
@ -173,7 +173,6 @@ def test_regenerate_with_github_fails_without_pat(client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
||||
def test_get_configured_types_via_api(client, sample_org_data):
|
||||
# Act
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
|
||||
|
@ -203,10 +202,10 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
|
|||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||
# Arrange
|
||||
state.SearchType = configure_search_types(config)
|
||||
original_config = state.config.content_type
|
||||
state.config.content_type = None
|
||||
state.anonymous_mode = True
|
||||
if state.config and state.config.content_type:
|
||||
state.config.content_type = None
|
||||
state.search_models = configure_search_types()
|
||||
|
||||
configure_routes(fastapi_app)
|
||||
client = TestClient(fastapi_app)
|
||||
|
@ -218,9 +217,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
|||
assert response.status_code == 200
|
||||
assert response.json() == ["all"]
|
||||
|
||||
# Restore
|
||||
state.config.content_type = original_config
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
|
@ -259,13 +255,30 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
|
|||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual_data contains "Khoj via Emacs" entry
|
||||
|
||||
assert len(response.json()) == 1, "Expected only 1 result"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "git clone https://github.com/khoj-ai/khoj" in search_result
|
||||
assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
    """An unrelated query against the indexed org notes returns no results within max_distance."""
    # Arrange
    text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
    auth_headers = {"Authorization": "Bearer kk-secret"}
    encoded_query = quote("How to find my goat?")
    search_url = f"/api/search?q={encoded_query}&n=1&t=org&r=true&max_distance=0.18"

    # Act
    response = client.get(search_url, headers=auth_headers)

    # Assert
    assert response.status_code == 200
    assert response.json() == [], "Expected no results"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
|
Loading…
Add table
Reference in a new issue