From 99754970abba637e32e5f781861e3d23ac17eb83 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 15 Sep 2022 13:57:20 +0300 Subject: [PATCH] Type the /search API response to better document the response schema - Both Text, Image Search were already giving list of entry, score - This change just concretizes this change and exposes this in the API documentation (i.e OpenAPI, Swagger, Redocs) --- src/routers/api_v1_0.py | 8 ++++---- src/search_type/image_search.py | 10 +++++----- src/search_type/text_search.py | 8 ++++---- src/utils/rawconfig.py | 5 +++++ 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/routers/api_v1_0.py b/src/routers/api_v1_0.py index 616dbc17..b6dea695 100644 --- a/src/routers/api_v1_0.py +++ b/src/routers/api_v1_0.py @@ -10,7 +10,7 @@ from fastapi import APIRouter # Internal Packages from src.configure import configure_search from src.search_type import image_search, text_search -from src.utils.rawconfig import FullConfig +from src.utils.rawconfig import FullConfig, SearchResponse from src.utils.config import SearchType from src.utils import state, constants @@ -31,16 +31,16 @@ async def config_data(updated_config: FullConfig): outfile.close() return state.config -@api_v1_0.get('/search') +@api_v1_0.get('/search', response_model=list[SearchResponse]) def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): + results: list[SearchResponse] = [] if q is None or q == '': logger.info(f'No query param (q) passed in API call to initiate search') - return {} + return results # initialize variables user_query = q.strip() results_count = n - results = {} query_start, query_end, collate_start, collate_end = None, None, None, None # return cached results, if available diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index d19a063a..e04bbe49 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -15,7 +15,7 @@ import torch # Internal Packages from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model from src.utils.config import ImageSearchModel -from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig +from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse # Create Logger @@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count): img.show() -def collate_results(hits, image_names, output_directory, image_files_url, count=5): - results = [] +def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]: + results: list[SearchResponse] = [] for index, hit in enumerate(hits[:count]): source_path = image_names[hit['corpus_id']] @@ -220,7 +220,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= shutil.copy(source_path, target_path) # Add the image metadata to the results - results += [ + results += [SearchResponse.parse_obj( { "entry": f'{image_files_url}/{target_image_name}', "score": f"{hit['score']:.9f}", @@ -230,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= "metadata_score": f"{hit['metadata_score']:.9f}", } } - ] + )] return results diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index ff7d9c43..009f39b9 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter from src.utils import state from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.config import TextSearchModel -from src.utils.rawconfig import TextSearchConfig, TextContentConfig +from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -171,12 +171,12 @@ def render_results(hits, entries, count=5, display_biencoder_results=False): print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") -def collate_results(hits, entries, count=5): - return [ +def collate_results(hits, entries, count=5) -> list[SearchResponse]: + return [SearchResponse.parse_obj( { "entry": entries[hit['corpus_id']]['raw'], "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" - } + }) for hit in hits[0:count]] diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 2c708569..84aadc0a 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -71,3 +71,8 @@ class FullConfig(ConfigBase): content_type: Optional[ContentConfig] search_type: Optional[SearchConfig] processor: Optional[ProcessorConfig] + +class SearchResponse(ConfigBase): + entry: str + score: str + additional: Optional[dict] \ No newline at end of file