Create Abstract Base Class for Filters. Make Word, Date Filter Child of BaseFilter

This commit is contained in:
Debanjum Singh Solanky 2022-09-04 18:05:38 +03:00
parent c9f6200007
commit e4418746f2
4 changed files with 30 additions and 4 deletions

View file

@ -0,0 +1,20 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Tuple
# External Packages
import torch
class BaseFilter(ABC):
@abstractmethod
def load(self, *args, **kwargs):
pass
@abstractmethod
def can_filter(self, raw_query:str) -> bool:
pass
@abstractmethod
def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]:
pass

View file

@ -1,7 +1,7 @@
# Standard Packages
import re
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from dateutil.relativedelta import relativedelta
from math import inf
from copy import deepcopy
@ -9,8 +9,11 @@ from copy import deepcopy
import torch
import dateparser as dtparse
# Internal Packages
from src.search_filter.base_filter import BaseFilter
class DateFilter:
class DateFilter(BaseFilter):
# Date Range Filter Regexes
# Example filter queries:
# - dt>="yesterday" dt<"tomorrow"

View file

@ -8,6 +8,7 @@ import logging
import torch
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU, resolve_absolute_path
from src.utils.config import SearchType
@ -15,7 +16,7 @@ from src.utils.config import SearchType
logger = logging.getLogger(__name__)
class WordFilter:
class WordFilter(BaseFilter):
# Filter Regex
required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?'

View file

@ -2,9 +2,11 @@
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import List
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig
from src.search_filter.base_filter import BaseFilter
class SearchType(str, Enum):
@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder