Type the /search API response to better document the response schema

- Both Text, Image Search were already giving list of entry, score
- This change just concretizes this change and exposes this in the API
  documentation (i.e OpenAPI, Swagger, Redocs)
This commit is contained in:
Debanjum Singh Solanky 2022-09-15 13:57:20 +03:00
parent 0521ea10d6
commit 99754970ab
4 changed files with 18 additions and 13 deletions

View file

@ -10,7 +10,7 @@ from fastapi import APIRouter
# Internal Packages # Internal Packages
from src.configure import configure_search from src.configure import configure_search
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig, SearchResponse
from src.utils.config import SearchType from src.utils.config import SearchType
from src.utils import state, constants from src.utils import state, constants
@ -31,16 +31,16 @@ async def config_data(updated_config: FullConfig):
outfile.close() outfile.close()
return state.config return state.config
@api_v1_0.get('/search') @api_v1_0.get('/search', response_model=list[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: list[SearchResponse] = []
if q is None or q == '': if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search') logger.info(f'No query param (q) passed in API call to initiate search')
return {} return results
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n results_count = n
results = {}
query_start, query_end, collate_start, collate_end = None, None, None, None query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available # return cached results, if available

View file

@ -15,7 +15,7 @@ import torch
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
# Create Logger # Create Logger
@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count):
img.show() img.show()
def collate_results(hits, image_names, output_directory, image_files_url, count=5): def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
results = [] results: list[SearchResponse] = []
for index, hit in enumerate(hits[:count]): for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']] source_path = image_names[hit['corpus_id']]
@ -220,7 +220,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
shutil.copy(source_path, target_path) shutil.copy(source_path, target_path)
# Add the image metadata to the results # Add the image metadata to the results
results += [ results += [SearchResponse.parse_obj(
{ {
"entry": f'{image_files_url}/{target_image_name}', "entry": f'{image_files_url}/{target_image_name}',
"score": f"{hit['score']:.9f}", "score": f"{hit['score']:.9f}",
@ -230,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
"metadata_score": f"{hit['metadata_score']:.9f}", "metadata_score": f"{hit['metadata_score']:.9f}",
} }
} }
] )]
return results return results

View file

@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
from src.utils import state from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl from src.utils.jsonl import load_jsonl
@ -171,12 +171,12 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
def collate_results(hits, entries, count=5): def collate_results(hits, entries, count=5) -> list[SearchResponse]:
return [ return [SearchResponse.parse_obj(
{ {
"entry": entries[hit['corpus_id']]['raw'], "entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
} })
for hit for hit
in hits[0:count]] in hits[0:count]]

View file

@ -71,3 +71,8 @@ class FullConfig(ConfigBase):
content_type: Optional[ContentConfig] content_type: Optional[ContentConfig]
search_type: Optional[SearchConfig] search_type: Optional[SearchConfig]
processor: Optional[ProcessorConfig] processor: Optional[ProcessorConfig]
class SearchResponse(ConfigBase):
entry: str
score: str
additional: Optional[dict]