Create BaseEncoder class. Make OpenAI encoder its child. Use for typing

- Set type of all bi_encoders to BaseEncoder

- Make load_model return type Union of CrossEncoder and BaseEncoder
This commit is contained in:
Debanjum Singh Solanky 2023-01-09 18:08:37 -03:00
parent cf7400759b
commit e5254a8e56
4 changed files with 25 additions and 6 deletions

View file

@@ -13,6 +13,7 @@ from src.search_filter.base_filter import BaseFilter
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
from src.utils.models import BaseEncoder
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
from src.utils.jsonl import load_jsonl
@@ -56,7 +57,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required

View file

@@ -9,6 +9,7 @@ import torch
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter
from src.utils.models import BaseEncoder
class SearchType(str, Enum):
@@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
@@ -34,7 +35,7 @@ class TextSearchModel():
class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings

View file

@@ -7,6 +7,12 @@ from collections import OrderedDict
from typing import Optional, Union
import logging
# External Packages
from sentence_transformers import CrossEncoder
# Internal Packages
from src.utils.models import BaseEncoder
def is_none_or_empty(item):
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
@@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None

View file

@@ -1,3 +1,6 @@
# Standard Packages
from abc import ABC, abstractmethod
# External Packages
import openai
import torch
@@ -7,7 +10,15 @@ from tqdm import trange
from src.utils.state import processor_config, config_file
class OpenAI:
class BaseEncoder(ABC):
@abstractmethod
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
@abstractmethod
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
self.model_name = model_name
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
@@ -15,7 +26,7 @@ class OpenAI:
openai.api_key = processor_config.conversation.openai_api_key
self.embedding_dimensions = None
def encode(self, entries: list[str], device=None, **kwargs):
def encode(self, entries, device=None, **kwargs):
embedding_tensors = []
for index in trange(0, len(entries)):