mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Type the /search API response to better document the response schema
- Both Text, Image Search were already giving list of entry, score - This change just concretizes this change and exposes this in the API documentation (i.e OpenAPI, Swagger, Redocs)
This commit is contained in:
parent
0521ea10d6
commit
99754970ab
4 changed files with 18 additions and 13 deletions
|
@ -10,7 +10,7 @@ from fastapi import APIRouter
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.configure import configure_search
|
from src.configure import configure_search
|
||||||
from src.search_type import image_search, text_search
|
from src.search_type import image_search, text_search
|
||||||
from src.utils.rawconfig import FullConfig
|
from src.utils.rawconfig import FullConfig, SearchResponse
|
||||||
from src.utils.config import SearchType
|
from src.utils.config import SearchType
|
||||||
from src.utils import state, constants
|
from src.utils import state, constants
|
||||||
|
|
||||||
|
@ -31,16 +31,16 @@ async def config_data(updated_config: FullConfig):
|
||||||
outfile.close()
|
outfile.close()
|
||||||
return state.config
|
return state.config
|
||||||
|
|
||||||
@api_v1_0.get('/search')
|
@api_v1_0.get('/search', response_model=list[SearchResponse])
|
||||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||||
|
results: list[SearchResponse] = []
|
||||||
if q is None or q == '':
|
if q is None or q == '':
|
||||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
logger.info(f'No query param (q) passed in API call to initiate search')
|
||||||
return {}
|
return results
|
||||||
|
|
||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n
|
results_count = n
|
||||||
results = {}
|
|
||||||
query_start, query_end, collate_start, collate_end = None, None, None, None
|
query_start, query_end, collate_start, collate_end = None, None, None, None
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
|
|
|
@ -15,7 +15,7 @@ import torch
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
||||||
from src.utils.config import ImageSearchModel
|
from src.utils.config import ImageSearchModel
|
||||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
|
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||||
|
|
||||||
|
|
||||||
# Create Logger
|
# Create Logger
|
||||||
|
@ -203,8 +203,8 @@ def render_results(hits, image_names, image_directory, count):
|
||||||
img.show()
|
img.show()
|
||||||
|
|
||||||
|
|
||||||
def collate_results(hits, image_names, output_directory, image_files_url, count=5):
|
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
|
||||||
results = []
|
results: list[SearchResponse] = []
|
||||||
|
|
||||||
for index, hit in enumerate(hits[:count]):
|
for index, hit in enumerate(hits[:count]):
|
||||||
source_path = image_names[hit['corpus_id']]
|
source_path = image_names[hit['corpus_id']]
|
||||||
|
@ -220,7 +220,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
shutil.copy(source_path, target_path)
|
shutil.copy(source_path, target_path)
|
||||||
|
|
||||||
# Add the image metadata to the results
|
# Add the image metadata to the results
|
||||||
results += [
|
results += [SearchResponse.parse_obj(
|
||||||
{
|
{
|
||||||
"entry": f'{image_files_url}/{target_image_name}',
|
"entry": f'{image_files_url}/{target_image_name}',
|
||||||
"score": f"{hit['score']:.9f}",
|
"score": f"{hit['score']:.9f}",
|
||||||
|
@ -230,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
)]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
|
||||||
from src.utils import state
|
from src.utils import state
|
||||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||||
from src.utils.config import TextSearchModel
|
from src.utils.config import TextSearchModel
|
||||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
|
||||||
from src.utils.jsonl import load_jsonl
|
from src.utils.jsonl import load_jsonl
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,12 +171,12 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
|
||||||
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
|
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
|
||||||
|
|
||||||
|
|
||||||
def collate_results(hits, entries, count=5):
|
def collate_results(hits, entries, count=5) -> list[SearchResponse]:
|
||||||
return [
|
return [SearchResponse.parse_obj(
|
||||||
{
|
{
|
||||||
"entry": entries[hit['corpus_id']]['raw'],
|
"entry": entries[hit['corpus_id']]['raw'],
|
||||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||||
}
|
})
|
||||||
for hit
|
for hit
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
|
@ -71,3 +71,8 @@ class FullConfig(ConfigBase):
|
||||||
content_type: Optional[ContentConfig]
|
content_type: Optional[ContentConfig]
|
||||||
search_type: Optional[SearchConfig]
|
search_type: Optional[SearchConfig]
|
||||||
processor: Optional[ProcessorConfig]
|
processor: Optional[ProcessorConfig]
|
||||||
|
|
||||||
|
class SearchResponse(ConfigBase):
|
||||||
|
entry: str
|
||||||
|
score: str
|
||||||
|
additional: Optional[dict]
|
Loading…
Reference in a new issue