mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Reorder embeddings search arguments based on argument importance
This commit is contained in:
parent
0eacc0b2b0
commit
6a8fd9bf33
4 changed files with 7 additions and 7 deletions
|
@ -1413,11 +1413,11 @@ class EntryAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_with_embeddings(
|
def search_with_embeddings(
|
||||||
user: KhojUser,
|
raw_query: str,
|
||||||
embeddings: Tensor,
|
embeddings: Tensor,
|
||||||
|
user: KhojUser,
|
||||||
max_results: int = 10,
|
max_results: int = 10,
|
||||||
file_type_filter: str = None,
|
file_type_filter: str = None,
|
||||||
raw_query: str = None,
|
|
||||||
max_distance: float = math.inf,
|
max_distance: float = math.inf,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
|
|
|
@ -160,8 +160,8 @@ async def execute_search(
|
||||||
search_futures += [
|
search_futures += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user,
|
|
||||||
user_query,
|
user_query,
|
||||||
|
user,
|
||||||
t,
|
t,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
max_distance=max_distance,
|
max_distance=max_distance,
|
||||||
|
|
|
@ -97,8 +97,8 @@ def load_embeddings(
|
||||||
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
user: KhojUser,
|
|
||||||
raw_query: str,
|
raw_query: str,
|
||||||
|
user: KhojUser,
|
||||||
type: SearchType = SearchType.All,
|
type: SearchType = SearchType.All,
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
max_distance: float = None,
|
max_distance: float = None,
|
||||||
|
@ -125,12 +125,12 @@ async def query(
|
||||||
top_k = 10
|
top_k = 10
|
||||||
with timer("Search Time", logger, state.device):
|
with timer("Search Time", logger, state.device):
|
||||||
hits = EntryAdapters.search_with_embeddings(
|
hits = EntryAdapters.search_with_embeddings(
|
||||||
user=user,
|
raw_query=raw_query,
|
||||||
embeddings=question_embedding,
|
embeddings=question_embedding,
|
||||||
max_results=top_k,
|
max_results=top_k,
|
||||||
file_type_filter=file_type,
|
file_type_filter=file_type,
|
||||||
raw_query=raw_query,
|
|
||||||
max_distance=max_distance,
|
max_distance=max_distance,
|
||||||
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
).all()
|
).all()
|
||||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||||
|
|
|
@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig):
|
||||||
query = "Load Khoj on Emacs?"
|
query = "Load Khoj on Emacs?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits = await text_search.query(default_user, query)
|
hits = await text_search.query(query, default_user)
|
||||||
results = text_search.collate_results(hits)
|
results = text_search.collate_results(hits)
|
||||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue