mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Create BaseEncoder class. Make OpenAI encoder its child. Use for typing
- Set type of all bi_encoders to BaseEncoder - Make load_model return type Union of CrossEncoder and BaseEncoder
This commit is contained in:
parent
cf7400759b
commit
e5254a8e56
4 changed files with 25 additions and 6 deletions
|
@ -13,6 +13,7 @@ from src.search_filter.base_filter import BaseFilter
|
|||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.models import BaseEncoder
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
|
@ -56,7 +57,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
|
|||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||
|
||||
|
||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
|
||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
|||
# Internal Packages
|
||||
from src.utils.rawconfig import ConversationProcessorConfig, Entry
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.models import BaseEncoder
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
|
@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
|
|||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
|
@ -34,7 +35,7 @@ class TextSearchModel():
|
|||
|
||||
|
||||
class ImageSearchModel():
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
|
|
|
@ -7,6 +7,12 @@ from collections import OrderedDict
|
|||
from typing import Optional, Union
|
||||
import logging
|
||||
|
||||
# External Packages
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.models import BaseEncoder
|
||||
|
||||
|
||||
def is_none_or_empty(item):
|
||||
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
|
||||
|
@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
|||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
|
||||
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# External Packages
|
||||
import openai
|
||||
import torch
|
||||
|
@ -7,7 +10,15 @@ from tqdm import trange
|
|||
from src.utils.state import processor_config, config_file
|
||||
|
||||
|
||||
class OpenAI:
|
||||
class BaseEncoder(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class OpenAI(BaseEncoder):
|
||||
def __init__(self, model_name, device=None):
|
||||
self.model_name = model_name
|
||||
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
|
||||
|
@ -15,7 +26,7 @@ class OpenAI:
|
|||
openai.api_key = processor_config.conversation.openai_api_key
|
||||
self.embedding_dimensions = None
|
||||
|
||||
def encode(self, entries: list[str], device=None, **kwargs):
|
||||
def encode(self, entries, device=None, **kwargs):
|
||||
embedding_tensors = []
|
||||
|
||||
for index in trange(0, len(entries)):
|
||||
|
|
Loading…
Add table
Reference in a new issue