mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Merge pull request #587 from khoj-ai/features/search-model-options-custom
Support multiple search models, with ability for custom user config
This commit is contained in:
commit
738f050086
18 changed files with 239 additions and 36 deletions
|
@ -66,7 +66,7 @@ dependencies = [
|
||||||
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
"httpx == 0.25.0",
|
"httpx == 0.25.0",
|
||||||
"pgvector == 0.2.3",
|
"pgvector == 0.2.4",
|
||||||
"psycopg2-binary == 2.9.9",
|
"psycopg2-binary == 2.9.9",
|
||||||
"google-auth == 2.23.3",
|
"google-auth == 2.23.3",
|
||||||
"python-multipart == 0.0.6",
|
"python-multipart == 0.0.6",
|
||||||
|
|
|
@ -25,7 +25,7 @@ from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
get_all_users,
|
get_all_users,
|
||||||
get_or_create_search_model,
|
get_or_create_search_models,
|
||||||
aget_user_subscription_state,
|
aget_user_subscription_state,
|
||||||
SubscriptionState,
|
SubscriptionState,
|
||||||
)
|
)
|
||||||
|
@ -146,8 +146,14 @@ def configure_server(
|
||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
search_models = get_or_create_search_models()
|
||||||
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
state.embeddings_model = dict()
|
||||||
|
state.cross_encoder_model = dict()
|
||||||
|
|
||||||
|
for model in search_models:
|
||||||
|
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)})
|
||||||
|
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
initialize_content(regenerate, search_type, init, user)
|
initialize_content(regenerate, search_type, init, user)
|
||||||
|
|
|
@ -32,6 +32,7 @@ from khoj.database.models import (
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
Subscription,
|
Subscription,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
|
UserSearchModelConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
|
@ -250,12 +251,33 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_user_search_model_or_default(user=None):
    """Return the search model config to use for ``user``.

    Resolution order:
      1. If ``user`` is truthy and has a ``UserSearchModelConfig`` row,
         return that row's ``setting``.
      2. Otherwise return the ``SearchModelConfig`` whose name is "default".
      3. If no such row exists, create a ``SearchModelConfig`` with the
         model's field defaults and return the first ``SearchModelConfig`` row.

    Each lookup is a single ``.first()`` query rather than ``.exists()``
    followed by ``.first()``, so a row deleted between the two calls cannot
    cause an attribute access on ``None``.
    """
    if user:
        user_config = UserSearchModelConfig.objects.filter(user=user).first()
        if user_config is not None:
            return user_config.setting

    default_model = SearchModelConfig.objects.filter(name="default").first()
    if default_model is not None:
        return default_model

    SearchModelConfig.objects.create()
    # NOTE(review): this returns the first row in the table, which is not
    # necessarily the row created on the line above when other, non-"default"
    # rows already exist -- confirm that is intended.
    return SearchModelConfig.objects.first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_search_models():
    """Return a queryset of all ``SearchModelConfig`` rows.

    When the table is empty, one ``SearchModelConfig`` is created with the
    model's field defaults first, so the returned queryset is never empty.
    """
    # exists() asks the database only whether a row is present, instead of
    # counting every row just to compare the total with zero.
    if not SearchModelConfig.objects.exists():
        SearchModelConfig.objects.create()

    # Querysets are lazy: this one is evaluated by the caller, after the
    # create above, so it includes the freshly created row.
    return SearchModelConfig.objects.all()
|
||||||
|
|
||||||
|
|
||||||
|
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
    """Point ``user``'s search model setting at the given ``SearchModelConfig``.

    Looks up the ``SearchModelConfig`` with id ``search_model_config_id``.
    Returns ``None`` when no such row exists. Otherwise creates or updates the
    user's ``UserSearchModelConfig`` so its ``setting`` is that row, and
    returns the ``UserSearchModelConfig``.
    """
    selected_model = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
    if selected_model is None:
        return None

    # The created/updated flag is not needed by callers, so it is discarded.
    user_config, _created = await UserSearchModelConfig.objects.aupdate_or_create(
        user=user,
        defaults={"setting": selected_model},
    )
    return user_config
|
||||||
|
|
||||||
|
|
||||||
class ConversationAdapters:
|
class ConversationAdapters:
|
||||||
|
|
|
@ -16,6 +16,7 @@ from khoj.database.models import (
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
Subscription,
|
Subscription,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
|
UserSearchModelConfig,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
)
|
)
|
||||||
|
@ -29,6 +30,7 @@ admin.site.register(OfflineChatProcessorConversationConfig)
|
||||||
admin.site.register(SearchModelConfig)
|
admin.site.register(SearchModelConfig)
|
||||||
admin.site.register(Subscription)
|
admin.site.register(Subscription)
|
||||||
admin.site.register(ReflectiveQuestion)
|
admin.site.register(ReflectiveQuestion)
|
||||||
|
admin.site.register(UserSearchModelConfig)
|
||||||
admin.site.register(TextToImageModelConfig)
|
admin.site.register(TextToImageModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Generated by Django 4.2.7 on 2023-12-19 15:44
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0022_texttoimagemodelconfig"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="UserSearchModelConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
(
|
||||||
|
"setting",
|
||||||
|
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.searchmodelconfig"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"user",
|
||||||
|
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
18
src/khoj/database/migrations/0024_alter_entry_embeddings.py
Normal file
18
src/khoj/database/migrations/0024_alter_entry_embeddings.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Generated by Django 4.2.7 on 2023-12-20 07:27
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
import pgvector.django
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0023_usersearchmodelconfig"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="entry",
|
||||||
|
name="embeddings",
|
||||||
|
field=pgvector.django.VectorField(),
|
||||||
|
),
|
||||||
|
]
|
|
@ -153,6 +153,11 @@ class UserConversationConfig(BaseModel):
|
||||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UserSearchModelConfig(BaseModel):
|
||||||
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
|
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
class Conversation(BaseModel):
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
conversation_log = models.JSONField(default=dict)
|
conversation_log = models.JSONField(default=dict)
|
||||||
|
@ -180,7 +185,7 @@ class Entry(BaseModel):
|
||||||
GITHUB = "github"
|
GITHUB = "github"
|
||||||
|
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
embeddings = VectorField(dimensions=384)
|
embeddings = VectorField(dimensions=None)
|
||||||
raw = models.TextField()
|
raw = models.TextField()
|
||||||
compiled = models.TextField()
|
compiled = models.TextField()
|
||||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||||
|
|
2
src/khoj/interface/web/assets/icons/web.svg
Normal file
2
src/khoj/interface/web/assets/icons/web.svg
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||||
|
<svg width="800px" height="800px" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg" stroke-width="3" stroke="#000000" fill="none"><circle cx="34.52" cy="11.43" r="5.82"/><circle cx="53.63" cy="31.6" r="5.82"/><circle cx="34.52" cy="50.57" r="5.82"/><circle cx="15.16" cy="42.03" r="5.82"/><circle cx="15.16" cy="19.27" r="5.82"/><circle cx="34.51" cy="29.27" r="4.7"/><line x1="20.17" y1="16.3" x2="28.9" y2="12.93"/><line x1="38.6" y1="15.59" x2="49.48" y2="27.52"/><line x1="50.07" y1="36.2" x2="38.67" y2="46.49"/><line x1="18.36" y1="24.13" x2="30.91" y2="46.01"/><line x1="20.31" y1="44.74" x2="28.7" y2="48.63"/><line x1="17.34" y1="36.63" x2="31.37" y2="16.32"/><line x1="20.52" y1="21.55" x2="30.34" y2="27.1"/><line x1="39.22" y1="29.8" x2="47.81" y2="30.45"/><line x1="34.51" y1="33.98" x2="34.52" y2="44.74"/></svg>
|
After Width: | Height: | Size: 951 B |
|
@ -296,6 +296,21 @@
|
||||||
height: 32px;
|
height: 32px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div#notification-banner {
|
||||||
|
background-color: var(--primary);
|
||||||
|
border: 1px solid var(--primary-hover);
|
||||||
|
padding: 8px;
|
||||||
|
box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
div#notification-banner-parent {
|
||||||
|
position: fixed;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
margin: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
select#search-models,
|
||||||
select#chat-models {
|
select#chat-models {
|
||||||
margin-bottom: 0;
|
margin-bottom: 0;
|
||||||
padding: 8px;
|
padding: 8px;
|
||||||
|
|
|
@ -16,6 +16,9 @@ Hi, I am Khoj, your open, personal AI 👋🏽. I can help:
|
||||||
- 🧠 Answer general knowledge questions
|
- 🧠 Answer general knowledge questions
|
||||||
- 💡 Be a sounding board for your ideas
|
- 💡 Be a sounding board for your ideas
|
||||||
- 📜 Chat with your notes & documents
|
- 📜 Chat with your notes & documents
|
||||||
|
- 🌄 Generate images based on your messages
|
||||||
|
- 🔎 Search the web for answers to your questions
|
||||||
|
- 🎙️ Listen to your audio messages
|
||||||
|
|
||||||
Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj.dev/#/obsidian?id=setup) or [Emacs](https://docs.khoj.dev/#/emacs?id=setup) app to search, chat with your 🖥️ computer docs.
|
Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj.dev/#/obsidian?id=setup) or [Emacs](https://docs.khoj.dev/#/emacs?id=setup) app to search, chat with your 🖥️ computer docs.
|
||||||
|
|
||||||
|
|
|
@ -107,6 +107,26 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-title-row">
|
||||||
|
<img class="card-icon" src="/static/assets/icons/web.svg" alt="Language">
|
||||||
|
<h3 class="card-title">
|
||||||
|
Language
|
||||||
|
</h3>
|
||||||
|
</div>
|
||||||
|
<div class="card-description-row">
|
||||||
|
<select id="search-models">
|
||||||
|
{% for option in search_model_options %}
|
||||||
|
<option value="{{ option.id }}" {% if option.id == selected_search_model_config %}selected{% endif %}>{{ option.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="card-action-row">
|
||||||
|
<button id="save-search-model" class="card-button happy" onclick="updateSearchModel()">
|
||||||
|
Save
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="general-settings section">
|
<div class="general-settings section">
|
||||||
<div id="status" style="display: none;"></div>
|
<div id="status" style="display: none;"></div>
|
||||||
|
@ -239,6 +259,10 @@
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<div class="section"></div>
|
<div class="section"></div>
|
||||||
|
</div>
|
||||||
|
<div class="section" id="notification-banner-parent">
|
||||||
|
<div id="notification-banner"></div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<script>
|
<script>
|
||||||
|
|
||||||
|
@ -266,6 +290,37 @@
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Save the search model selected in the #search-models dropdown.
//
// Disables the #save-search-model button while the POST to
// /api/config/data/search/model is in flight, then restores it labelled
// "Save" on success or "Error" on failure. After any server response a
// reminder is shown in #notification-banner for 5 seconds.
function updateSearchModel() {
    const searchModel = document.getElementById("search-models").value;
    const saveSearchModelButton = document.getElementById("save-search-model");
    saveSearchModelButton.disabled = true;
    saveSearchModelButton.innerHTML = "Saving...";

    // Encode the id so an unexpected option value cannot corrupt the query string.
    fetch('/api/config/data/search/model?id=' + encodeURIComponent(searchModel), {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        }
    })
    .then(response => response.json())
    .then(data => {
        saveSearchModelButton.innerHTML = data.status == "ok" ? "Save" : "Error";
        saveSearchModelButton.disabled = false;

        let notificationBanner = document.getElementById("notification-banner");
        notificationBanner.innerHTML = "When updating the language model, be sure to delete all your saved content and re-initialize.";
        notificationBanner.style.display = "block";
        setTimeout(function() {
            notificationBanner.style.display = "none";
        }, 5000);
    })
    .catch(() => {
        // A network failure or non-JSON response rejects the promise chain.
        // Without this handler the button would stay disabled on "Saving..." forever.
        saveSearchModelButton.innerHTML = "Error";
        saveSearchModelButton.disabled = false;
    });
};
|
||||||
|
|
||||||
function clearContentType(content_source) {
|
function clearContentType(content_source) {
|
||||||
fetch('/api/config/data/content-source/' + content_source, {
|
fetch('/api/config/data/content-source/' + content_source, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
|
|
|
@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||||
from khoj.database.adapters import EntryAdapters
|
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -112,7 +112,8 @@ class TextToEntries(ABC):
|
||||||
with timer("Generated embeddings for entries to add to database in", logger):
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
embeddings += self.embeddings_model.embed_documents(data_to_embed)
|
model = get_user_search_model_or_default(user)
|
||||||
|
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||||
|
|
||||||
added_entries: list[DbEntry] = []
|
added_entries: list[DbEntry] = []
|
||||||
with timer("Added entries to database in", logger):
|
with timer("Added entries to database in", logger):
|
||||||
|
|
|
@ -18,7 +18,7 @@ from starlette.authentication import requires
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default
|
||||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
|
@ -332,6 +332,31 @@ async def update_chat_model(
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@api.post("/config/data/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the search model used by the authenticated user.

    Args:
        request: Incoming request; ``request.user.object`` is the user to update.
        id: Id of the search model config to select, as a string.
        client: Optional client name, forwarded to telemetry.

    Returns:
        ``{"status": "ok"}`` on success, otherwise a dict with
        ``"status": "error"`` and a ``"message"`` describing the failure.
    """
    user = request.user.object

    # The id arrives as a string. A non-numeric value is reported through the
    # endpoint's normal error payload instead of escaping as an unhandled
    # ValueError.
    try:
        search_model_config_id = int(id)
    except ValueError:
        return {"status": "error", "message": "Invalid model id"}

    new_config = await adapters.aset_user_search_model(user, search_model_config_id)

    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# Create Routes
|
# Create Routes
|
||||||
@api.get("/config/data/default")
|
@api.get("/config/data/default")
|
||||||
def get_default_config_data():
|
def get_default_config_data():
|
||||||
|
@ -410,13 +435,10 @@ async def search(
|
||||||
defiltered_query = filter.defilter(defiltered_query)
|
defiltered_query = filter.defilter(defiltered_query)
|
||||||
|
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t == SearchType.All or t != SearchType.Image:
|
if t != SearchType.Image:
|
||||||
text_search_models: List[TextSearchModel] = [
|
|
||||||
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
|
||||||
]
|
|
||||||
if text_search_models:
|
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query)
|
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||||
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if t in [
|
if t in [
|
||||||
|
@ -472,9 +494,9 @@ async def search(
|
||||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||||
|
|
||||||
# Sort results across all content types and take top results
|
# Sort results across all content types and take top results
|
||||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
results = text_search.rerank_and_sort_results(
|
||||||
:results_count
|
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
|
||||||
]
|
)[:results_count]
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
if user:
|
if user:
|
||||||
|
|
|
@ -143,6 +143,13 @@ async def update(
|
||||||
)
|
)
|
||||||
return Response(content="Failed", status_code=500)
|
return Response(content="Failed", status_code=500)
|
||||||
|
|
||||||
|
indexing_metadata = {
|
||||||
|
"num_org": len(org_files),
|
||||||
|
"num_markdown": len(markdown_files),
|
||||||
|
"num_pdf": len(pdf_files),
|
||||||
|
"num_plaintext": len(plaintext_files),
|
||||||
|
}
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
|
@ -151,6 +158,7 @@ async def update(
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
referer=referer,
|
referer=referer,
|
||||||
host=host,
|
host=host,
|
||||||
|
metadata=indexing_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"📪 Content index updated via API call by {client} client")
|
logger.info(f"📪 Content index updated via API call by {client} client")
|
||||||
|
|
|
@ -155,6 +155,11 @@ def config_page(request: Request):
|
||||||
for conversation_option in conversation_options:
|
for conversation_option in conversation_options:
|
||||||
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
||||||
|
|
||||||
|
search_model_options = adapters.get_or_create_search_models().all()
|
||||||
|
all_search_model_options = list()
|
||||||
|
for search_model_option in search_model_options:
|
||||||
|
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"config.html",
|
"config.html",
|
||||||
context={
|
context={
|
||||||
|
@ -163,6 +168,7 @@ def config_page(request: Request):
|
||||||
"anonymous_mode": state.anonymous_mode,
|
"anonymous_mode": state.anonymous_mode,
|
||||||
"username": user.username,
|
"username": user.username,
|
||||||
"conversation_options": all_conversation_options,
|
"conversation_options": all_conversation_options,
|
||||||
|
"search_model_options": all_search_model_options,
|
||||||
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
||||||
"user_photo": user_picture,
|
"user_photo": user_picture,
|
||||||
"billing_enabled": state.billing_enabled,
|
"billing_enabled": state.billing_enabled,
|
||||||
|
|
|
@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
|
||||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||||
from khoj.utils.jsonl import load_jsonl
|
from khoj.utils.jsonl import load_jsonl
|
||||||
from khoj.processor.content.text_to_entries import TextToEntries
|
from khoj.processor.content.text_to_entries import TextToEntries
|
||||||
from khoj.database.adapters import EntryAdapters
|
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||||
from khoj.database.models import KhojUser, Entry as DbEntry
|
from khoj.database.models import KhojUser, Entry as DbEntry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -115,7 +115,8 @@ async def query(
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
with timer("Query Encode Time", logger, state.device):
|
with timer("Query Encode Time", logger, state.device):
|
||||||
question_embedding = state.embeddings_model.embed_query(query)
|
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||||
|
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
top_k = 10
|
top_k = 10
|
||||||
|
@ -179,13 +180,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rerank_and_sort_results(hits, query, rank_results):
|
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||||
# If we have more than one result and reranking is enabled
|
# If we have more than one result and reranking is enabled
|
||||||
rank_results = rank_results and len(list(hits)) > 1
|
rank_results = rank_results and len(list(hits)) > 1
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits = cross_encoder_score(query, hits)
|
hits = cross_encoder_score(query, hits, search_model_name)
|
||||||
|
|
||||||
# Sort results by cross-encoder score followed by bi-encoder score
|
# Sort results by cross-encoder score followed by bi-encoder score
|
||||||
hits = sort_results(rank_results=rank_results, hits=hits)
|
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||||
|
@ -218,10 +219,10 @@ def setup(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||||
"""Score all retrieved entries using the cross-encoder"""
|
"""Score all retrieved entries using the cross-encoder"""
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||||
|
|
||||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
|
|
|
@ -19,8 +19,8 @@ from khoj.utils.rawconfig import FullConfig
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
embeddings_model: EmbeddingsModel = None
|
embeddings_model: Dict[str, EmbeddingsModel] = None
|
||||||
cross_encoder_model: CrossEncoderModel = None
|
cross_encoder_model: Dict[str, CrossEncoderModel] = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
openai_client: OpenAI = None
|
openai_client: OpenAI = None
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
|
|
|
@ -45,8 +45,10 @@ def enable_db_access_for_all_tests(db):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_config() -> SearchConfig:
|
def search_config() -> SearchConfig:
|
||||||
state.embeddings_model = EmbeddingsModel()
|
state.embeddings_model = dict()
|
||||||
state.cross_encoder_model = CrossEncoderModel()
|
state.embeddings_model["default"] = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = dict()
|
||||||
|
state.cross_encoder_model["default"] = CrossEncoderModel()
|
||||||
|
|
||||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -317,8 +319,10 @@ def client(
|
||||||
state.config.content_type = content_config
|
state.config.content_type = content_config
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.embeddings_model = EmbeddingsModel()
|
state.embeddings_model = dict()
|
||||||
state.cross_encoder_model = CrossEncoderModel()
|
state.embeddings_model["default"] = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = dict()
|
||||||
|
state.cross_encoder_model["default"] = CrossEncoderModel()
|
||||||
|
|
||||||
# These lines help us Mock the Search models for these search types
|
# These lines help us Mock the Search models for these search types
|
||||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
|
|
Loading…
Reference in a new issue