mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Create Abstract Base Class for Filters. Make Word, Date Filter Child of BaseFilter
This commit is contained in:
parent
c9f6200007
commit
e4418746f2
4 changed files with 30 additions and 4 deletions
20
src/search_filter/base_filter.py
Normal file
20
src/search_filter/base_filter.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
||||
|
||||
class BaseFilter(ABC):
    """Abstract interface that every search filter must implement.

    Concrete filters subclass this and provide all three methods below;
    the class cannot be instantiated until each of them is overridden.
    """

    @abstractmethod
    def load(self, *args, **kwargs):
        """Prepare the filter for use.

        Accepts arbitrary positional and keyword arguments; what they mean
        is left to each concrete filter.
        """

    @abstractmethod
    def can_filter(self, raw_query: str) -> bool:
        """Report whether this filter applies to ``raw_query``.

        NOTE(review): the exact matching rule is defined by subclasses.
        """

    @abstractmethod
    def apply(
        self,
        query: str,
        raw_entries: List[str],
        raw_embeddings: torch.Tensor,
    ) -> Tuple[str, List[str], torch.Tensor]:
        """Run the filter over a query and its candidate entries.

        Takes the query string, the list of entries and their embeddings
        tensor, and returns a ``(query, entries, embeddings)`` tuple of the
        same types. How the three values are transformed is up to the
        concrete filter.
        """
|
|
@ -1,7 +1,7 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
from datetime import timedelta, datetime
|
||||
from dateutil.relativedelta import relativedelta, MO
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from math import inf
|
||||
from copy import deepcopy
|
||||
|
||||
|
@ -9,8 +9,11 @@ from copy import deepcopy
|
|||
import torch
|
||||
import dateparser as dtparse
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
class DateFilter:
|
||||
|
||||
class DateFilter(BaseFilter):
|
||||
# Date Range Filter Regexes
|
||||
# Example filter queries:
|
||||
# - dt>="yesterday" dt<"tomorrow"
|
||||
|
|
|
@ -8,6 +8,7 @@ import logging
|
|||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU, resolve_absolute_path
|
||||
from src.utils.config import SearchType
|
||||
|
||||
|
@ -15,7 +16,7 @@ from src.utils.config import SearchType
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordFilter:
|
||||
class WordFilter(BaseFilter):
|
||||
# Filter Regex
|
||||
required_regex = r'\+"(\w+)" ?'
|
||||
blocked_regex = r'\-"(\w+)" ?'
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import ConversationProcessorConfig
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
|
@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
|
|||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
|
|
Loading…
Add table
Reference in a new issue