mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Wrap asymmetric search model into SearchModels. Test notes search end-to-end
- Wrap asymmetric search model parameters into AsymmetricSearchModel class - Create wrapper for all search type models. Put notes search model into it - Test notes search end-to-end from client API layer to results. Use model build on test data
This commit is contained in:
parent
cde11a2331
commit
e22e0b41e3
4 changed files with 51 additions and 27 deletions
20
src/main.py
20
src/main.py
|
@ -11,9 +11,11 @@ from fastapi import FastAPI
|
||||||
from search_type import asymmetric, symmetric_ledger, image_search
|
from search_type import asymmetric, symmetric_ledger, image_search
|
||||||
from utils.helpers import get_from_dict
|
from utils.helpers import get_from_dict
|
||||||
from utils.cli import cli
|
from utils.cli import cli
|
||||||
from utils.config import SearchType, SearchSettings
|
from utils.config import SearchType, SearchSettings, SearchModels
|
||||||
|
|
||||||
|
|
||||||
|
# Application Global State
|
||||||
|
model = SearchModels()
|
||||||
search_settings = SearchSettings()
|
search_settings = SearchSettings()
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -29,16 +31,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
|
|
||||||
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
||||||
# query notes
|
# query notes
|
||||||
hits = asymmetric.query_notes(
|
hits = asymmetric.query_notes(user_query, model.notes_search)
|
||||||
user_query,
|
|
||||||
corpus_embeddings,
|
|
||||||
entries,
|
|
||||||
bi_encoder,
|
|
||||||
cross_encoder,
|
|
||||||
top_k)
|
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
return asymmetric.collate_results(hits, entries, results_count)
|
return asymmetric.collate_results(hits, model.notes_search.entries, results_count)
|
||||||
|
|
||||||
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
if (t == SearchType.Music or t == None) and search_settings.music_search_enabled:
|
||||||
# query music library
|
# query music library
|
||||||
|
@ -90,9 +86,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled:
|
||||||
# Extract Entries, Generate Embeddings
|
# Extract Entries, Generate Embeddings
|
||||||
global corpus_embeddings
|
models.notes_search = asymmetric.setup(
|
||||||
global entries
|
|
||||||
entries, corpus_embeddings, _, _, _ = asymmetric.setup(
|
|
||||||
org_config['input-files'],
|
org_config['input-files'],
|
||||||
org_config['input-filter'],
|
org_config['input-filter'],
|
||||||
pathlib.Path(org_config['compressed-jsonl']),
|
pathlib.Path(org_config['compressed-jsonl']),
|
||||||
|
@ -146,7 +140,7 @@ if __name__ == '__main__':
|
||||||
org_config = get_from_dict(args.config, 'content-type', 'org')
|
org_config = get_from_dict(args.config, 'content-type', 'org')
|
||||||
if org_config and ('input-files' in org_config or 'input-filter' in org_config):
|
if org_config and ('input-files' in org_config or 'input-filter' in org_config):
|
||||||
search_settings.notes_search_enabled = True
|
search_settings.notes_search_enabled = True
|
||||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(
|
model.notes_search = asymmetric.setup(
|
||||||
org_config['input-files'],
|
org_config['input-files'],
|
||||||
org_config['input-filter'],
|
org_config['input-filter'],
|
||||||
pathlib.Path(org_config['compressed-jsonl']),
|
pathlib.Path(org_config['compressed-jsonl']),
|
||||||
|
|
|
@ -17,6 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from utils.helpers import get_absolute_path, resolve_absolute_path
|
from utils.helpers import get_absolute_path, resolve_absolute_path
|
||||||
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
from processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||||
|
from utils.config import AsymmetricSearchModel
|
||||||
|
|
||||||
|
|
||||||
def initialize_model():
|
def initialize_model():
|
||||||
|
@ -64,7 +65,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
|
def query_notes(raw_query: str, model: AsymmetricSearchModel):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
|
@ -72,20 +73,22 @@ def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder
|
||||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||||
|
|
||||||
# Encode the query using the bi-encoder
|
# Encode the query using the bi-encoder
|
||||||
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True)
|
||||||
|
|
||||||
# Find relevant entries for the query
|
# Find relevant entries for the query
|
||||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k)
|
||||||
hits = hits[0] # Get the hits for the first query
|
hits = hits[0] # Get the hits for the first query
|
||||||
|
|
||||||
# Filter results using explicit filters
|
# Filter results using explicit filters
|
||||||
hits = explicit_filter(hits, [entry[0] for entry in entries], required_words, blocked_words)
|
hits = explicit_filter(hits,
|
||||||
|
[entry[0] for entry in model.entries],
|
||||||
|
required_words,blocked_words)
|
||||||
if hits is None or len(hits) == 0:
|
if hits is None or len(hits) == 0:
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits]
|
cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits]
|
||||||
cross_scores = cross_encoder.predict(cross_inp)
|
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Store cross-encoder scores in results dictionary for ranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
|
@ -161,7 +164,7 @@ def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=Fa
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k
|
return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from main import app
|
from main import app, search_settings, model
|
||||||
from search_type import asymmetric
|
from search_type import asymmetric
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,18 +55,33 @@ def test_regenerate_with_valid_search_type():
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def test_asymmetric_setup():
|
def test_notes_search():
|
||||||
# Arrange
|
# Arrange
|
||||||
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')]
|
input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')]
|
||||||
input_filter = None
|
input_filter = None
|
||||||
compressed_jsonl = Path('tests/data/.test.jsonl.gz')
|
compressed_jsonl = Path('tests/data/.test.jsonl.gz')
|
||||||
embeddings = Path('tests/data/.test_embeddings.pt')
|
embeddings = Path('tests/data/.test_embeddings.pt')
|
||||||
regenerate = False
|
|
||||||
verbose = 1
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate, verbose)
|
# Regenerate embeddings during asymmetric setup
|
||||||
|
notes_model = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=True, verbose=0)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 10
|
assert len(notes_model.entries) == 10
|
||||||
assert len(corpus_embeddings) == 10
|
assert len(notes_model.corpus_embeddings) == 10
|
||||||
|
|
||||||
|
# Arrange
|
||||||
|
model.notes_search = notes_model
|
||||||
|
search_settings.notes_search_enabled = True
|
||||||
|
user_query = "How to call semantic search from Emacs?"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/search?q={user_query}&n=1&t=notes")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
# assert actual_data contains "Semantic Search via Emacs"
|
||||||
|
search_result = response.json()[0]["Entry"]
|
||||||
|
assert "Semantic Search via Emacs" in search_result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,3 +18,15 @@ class SearchSettings():
|
||||||
image_search_enabled: bool = False
|
image_search_enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AsymmetricSearchModel():
|
||||||
|
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
|
||||||
|
self.entries = entries
|
||||||
|
self.corpus_embeddings = corpus_embeddings
|
||||||
|
self.bi_encoder = bi_encoder
|
||||||
|
self.cross_encoder = cross_encoder
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchModels():
|
||||||
|
notes_search: AsymmetricSearchModel = None
|
||||||
|
|
Loading…
Add table
Reference in a new issue