mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Type the /search API response to better document the response schema
- Both Text, Image Search were already giving list of entry, score - This change just concretizes this change and exposes this in the API documentation (i.e OpenAPI, Swagger, Redocs)
This commit is contained in:
parent
0521ea10d6
commit
99754970ab
4 changed files with 18 additions and 13 deletions
|
@ -10,7 +10,7 @@ from fastapi import APIRouter
|
|||
# Internal Packages
|
||||
from src.configure import configure_search
|
||||
from src.search_type import image_search, text_search
|
||||
from src.utils.rawconfig import FullConfig
|
||||
from src.utils.rawconfig import FullConfig, SearchResponse
|
||||
from src.utils.config import SearchType
|
||||
from src.utils import state, constants
|
||||
|
||||
|
@ -31,16 +31,16 @@ async def config_data(updated_config: FullConfig):
|
|||
outfile.close()
|
||||
return state.config
|
||||
|
||||
@api_v1_0.get('/search')
|
||||
@api_v1_0.get('/search', response_model=list[SearchResponse])
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||
results: list[SearchResponse] = []
|
||||
if q is None or q == '':
|
||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
||||
return {}
|
||||
return results
|
||||
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n
|
||||
results = {}
|
||||
query_start, query_end, collate_start, collate_end = None, None, None, None
|
||||
|
||||
# return cached results, if available
|
||||
|
|
|
@ -15,7 +15,7 @@ import torch
|
|||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
|
||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||
|
||||
|
||||
# Create Logger
|
||||
|
@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count):
|
|||
img.show()
|
||||
|
||||
|
||||
def collate_results(hits, image_names, output_directory, image_files_url, count=5):
|
||||
results = []
|
||||
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
|
||||
results: list[SearchResponse] = []
|
||||
|
||||
for index, hit in enumerate(hits[:count]):
|
||||
source_path = image_names[hit['corpus_id']]
|
||||
|
@ -220,7 +220,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
|||
shutil.copy(source_path, target_path)
|
||||
|
||||
# Add the image metadata to the results
|
||||
results += [
|
||||
results += [SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": f'{image_files_url}/{target_image_name}',
|
||||
"score": f"{hit['score']:.9f}",
|
||||
|
@ -230,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
|||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||
}
|
||||
}
|
||||
]
|
||||
)]
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
|
|||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
|
||||
|
@ -171,12 +171,12 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
|
|||
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
|
||||
|
||||
|
||||
def collate_results(hits, entries, count=5):
|
||||
return [
|
||||
def collate_results(hits, entries, count=5) -> list[SearchResponse]:
|
||||
return [SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": entries[hit['corpus_id']]['raw'],
|
||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||
}
|
||||
})
|
||||
for hit
|
||||
in hits[0:count]]
|
||||
|
||||
|
|
|
@ -71,3 +71,8 @@ class FullConfig(ConfigBase):
|
|||
content_type: Optional[ContentConfig]
|
||||
search_type: Optional[SearchConfig]
|
||||
processor: Optional[ProcessorConfig]
|
||||
|
||||
class SearchResponse(ConfigBase):
|
||||
entry: str
|
||||
score: str
|
||||
additional: Optional[dict]
|
Loading…
Reference in a new issue