Use async/await to fix parallelization of search across content types

This commit is contained in:
Debanjum Singh Solanky 2023-06-20 19:52:57 -07:00
parent 1192e49307
commit 5c7c8d1f46
2 changed files with 5 additions and 5 deletions

View file

@ -114,7 +114,7 @@ async def set_processor_conversation_config_data(updated_config: ConversationPro
@api.get("/search", response_model=List[SearchResponse]) @api.get("/search", response_model=List[SearchResponse])
def search( async def search(
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
t: Optional[SearchType] = None, t: Optional[SearchType] = None,
@ -257,9 +257,9 @@ def search(
# Query across each requested content types in parallel # Query across each requested content types in parallel
with timer("Query took", logger): with timer("Query took", logger):
for search_future in search_futures[t]: for search_future in concurrent.futures.as_completed(search_futures[t]):
if t == SearchType.Image: if t == SearchType.Image:
hits = search_futures.result() hits = await search_future.result()
output_directory = constants.web_directory / "images" output_directory = constants.web_directory / "images"
# Collate results # Collate results
results += image_search.collate_results( results += image_search.collate_results(
@ -270,7 +270,7 @@ def search(
count=results_count or 5, count=results_count or 5,
) )
else: else:
hits, entries = search_future.result() hits, entries = await search_future.result()
# Collate results # Collate results
results += text_search.collate_results(hits, entries, results_count or 5) results += text_search.collate_results(hits, entries, results_count or 5)

View file

@ -102,7 +102,7 @@ def compute_embeddings(
return corpus_embeddings return corpus_embeddings
def query( async def query(
raw_query: str, raw_query: str,
model: TextSearchModel, model: TextSearchModel,
question_embedding: torch.Tensor | None = None, question_embedding: torch.Tensor | None = None,