mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add per-user support for configuring the preferred search model from the config page
- Honor this setting in the relevant places where embeddings are used
- Set the VectorField dimensions to None so the search model is easily configurable
This commit is contained in:
parent
0f6e4ff683
commit
5ff9df9d4c
10 changed files with 117 additions and 17 deletions
|
@ -66,7 +66,7 @@ dependencies = [
|
|||
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||
"itsdangerous == 2.1.2",
|
||||
"httpx == 0.25.0",
|
||||
"pgvector == 0.2.3",
|
||||
"pgvector == 0.2.4",
|
||||
"psycopg2-binary == 2.9.9",
|
||||
"google-auth == 2.23.3",
|
||||
"python-multipart == 0.0.6",
|
||||
|
|
|
@ -269,6 +269,14 @@ def get_or_create_search_models():
|
|||
return search_models
|
||||
|
||||
|
||||
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
    """Record the given search model config as this user's preferred model.

    Returns the user's UserSearchModelConfig row, or None when no
    SearchModelConfig with the requested id exists.
    """
    config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
    if config is None:
        return None
    # Each user keeps a single preference row; update it in place when present.
    user_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(
        user=user, defaults={"setting": config}
    )
    return user_config
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
|
|
18
src/khoj/database/migrations/0024_alter_entry_embeddings.py
Normal file
18
src/khoj/database/migrations/0024_alter_entry_embeddings.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.2.7 on 2023-12-20 07:27

from django.db import migrations

import pgvector.django


class Migration(migrations.Migration):
    """Drop the hard-coded dimension on Entry.embeddings.

    A dimension-less VectorField lets the embedding size follow whichever
    search model is configured, rather than being fixed in the schema.
    """

    dependencies = [
        ("database", "0023_usersearchmodelconfig"),
    ]

    operations = [
        migrations.AlterField(
            model_name="entry",
            name="embeddings",
            # No dimensions argument: pgvector creates an untyped vector column.
            field=pgvector.django.VectorField(),
        ),
    ]
|
|
@ -185,7 +185,7 @@ class Entry(BaseModel):
|
|||
GITHUB = "github"
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
embeddings = VectorField(dimensions=384)
|
||||
embeddings = VectorField(dimensions=None)
|
||||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
|
||||
|
|
2
src/khoj/interface/web/assets/icons/matrix_blob.svg
Normal file
2
src/khoj/interface/web/assets/icons/matrix_blob.svg
Normal file
|
@ -0,0 +1,2 @@
|
|||
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg" stroke-width="3" stroke="#000000" fill="none"><circle cx="34.52" cy="11.43" r="5.82"/><circle cx="53.63" cy="31.6" r="5.82"/><circle cx="34.52" cy="50.57" r="5.82"/><circle cx="15.16" cy="42.03" r="5.82"/><circle cx="15.16" cy="19.27" r="5.82"/><circle cx="34.51" cy="29.27" r="4.7"/><line x1="20.17" y1="16.3" x2="28.9" y2="12.93"/><line x1="38.6" y1="15.59" x2="49.48" y2="27.52"/><line x1="50.07" y1="36.2" x2="38.67" y2="46.49"/><line x1="18.36" y1="24.13" x2="30.91" y2="46.01"/><line x1="20.31" y1="44.74" x2="28.7" y2="48.63"/><line x1="17.34" y1="36.63" x2="31.37" y2="16.32"/><line x1="20.52" y1="21.55" x2="30.34" y2="27.1"/><line x1="39.22" y1="29.8" x2="47.81" y2="30.45"/><line x1="34.51" y1="33.98" x2="34.52" y2="44.74"/></svg>
|
After Width: | Height: | Size: 951 B |
|
@ -296,6 +296,7 @@
|
|||
height: 32px;
|
||||
}
|
||||
|
||||
select#search-models,
|
||||
select#chat-models {
|
||||
margin-bottom: 0;
|
||||
padding: 8px;
|
||||
|
|
|
@ -146,6 +146,26 @@
|
|||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/matrix_blob.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Text Model
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<select id="search-models">
|
||||
{% for option in search_model_options %}
|
||||
<option value="{{ option.id }}" {% if option.id == selected_search_model_config %}selected{% endif %}>{{ option.name }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<button id="save-search-model" class="card-button happy" onclick="updateSearchModel()">
|
||||
Save
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div id="clients" class="section">
|
||||
|
@ -266,6 +286,30 @@
|
|||
})
|
||||
};
|
||||
|
||||
function updateSearchModel() {
    // Persist the selected search (embeddings) model for the signed-in user.
    const searchModel = document.getElementById("search-models").value;
    const saveSearchModelButton = document.getElementById("save-search-model");
    saveSearchModelButton.disabled = true;
    saveSearchModelButton.innerHTML = "Saving...";

    fetch('/api/config/data/search/model?id=' + searchModel, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        }
    })
    .then(response => response.json())
    .then(data => {
        saveSearchModelButton.innerHTML = data.status == "ok" ? "Save" : "Error";
    })
    .catch(() => {
        // Network failure or invalid JSON: without this handler the button
        // would stay disabled and stuck on "Saving...".
        saveSearchModelButton.innerHTML = "Error";
    })
    .finally(() => {
        saveSearchModelButton.disabled = false;
    });
};
|
||||
|
||||
function clearContentType(content_source) {
|
||||
fetch('/api/config/data/content-source/' + content_source, {
|
||||
method: 'DELETE',
|
||||
|
|
|
@ -332,6 +332,31 @@ async def update_chat_model(
|
|||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api.post("/config/data/search/model", status_code=200)
@requires(["authenticated"])
async def update_search_model(
    request: Request,
    id: str,
    client: Optional[str] = None,
):
    """Set the authenticated user's preferred search (embeddings) model.

    Renamed from update_chat_model: reusing that name shadowed the chat-model
    endpoint's handler defined just above. The HTTP route is unchanged.

    Args:
        request: Incoming request; the user is read from request.user.
        id: Primary key of the SearchModelConfig to select.
        client: Optional client identifier, recorded in telemetry.

    Returns:
        {"status": "ok"} on success; an error payload when the referenced
        model config does not exist.
    """
    user = request.user.object

    new_config = await adapters.aset_user_search_model(user, int(id))

    # Bail out before the telemetry call below touches new_config.setting,
    # otherwise an unknown id raises AttributeError instead of returning an error.
    if new_config is None:
        return {"status": "error", "message": "Model not found"}

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="set_search_model",
        client=client,
        metadata={"search_model": new_config.setting.name},
    )

    return {"status": "ok"}
|
||||
|
||||
|
||||
# Create Routes
|
||||
@api.get("/config/data/default")
|
||||
def get_default_config_data():
|
||||
|
@ -410,14 +435,10 @@ async def search(
|
|||
defiltered_query = filter.defilter(defiltered_query)
|
||||
|
||||
encoded_asymmetric_query = None
|
||||
if t == SearchType.All or t != SearchType.Image:
|
||||
text_search_models: List[TextSearchModel] = [
|
||||
model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
|
||||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
if t != SearchType.Image:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if t in [
|
||||
|
@ -473,9 +494,9 @@ async def search(
|
|||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
# Sort results across all content types and take top results
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||
:results_count
|
||||
]
|
||||
results = text_search.rerank_and_sort_results(
|
||||
results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
|
||||
)[:results_count]
|
||||
|
||||
# Cache results
|
||||
if user:
|
||||
|
|
|
@ -155,6 +155,11 @@ def config_page(request: Request):
|
|||
for conversation_option in conversation_options:
|
||||
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
||||
|
||||
search_model_options = adapters.get_or_create_search_models().all()
|
||||
all_search_model_options = list()
|
||||
for search_model_option in search_model_options:
|
||||
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"config.html",
|
||||
context={
|
||||
|
@ -163,6 +168,7 @@ def config_page(request: Request):
|
|||
"anonymous_mode": state.anonymous_mode,
|
||||
"username": user.username,
|
||||
"conversation_options": all_conversation_options,
|
||||
"search_model_options": all_search_model_options,
|
||||
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
|
||||
"user_photo": user_picture,
|
||||
"billing_enabled": state.billing_enabled,
|
||||
|
|
|
@ -180,13 +180,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query, rank_results):
|
||||
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||
# If we have more than one result and reranking is enabled
|
||||
rank_results = rank_results and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(query, hits)
|
||||
hits = cross_encoder_score(query, hits, search_model_name)
|
||||
|
||||
# Sort results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||
|
@ -219,10 +219,10 @@ def setup(
|
|||
)
|
||||
|
||||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||
"""Score all retrieved entries using the cross-encoder"""
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
|
|
Loading…
Reference in a new issue