Fix, add typing to Filter and TextSearchModel classes

- Changes
  - Fix method signatures of BaseFilter subclasses.
    Else typing information isn't translating to them
  - Explicitly pass `entries: list[Entry]' as arg to `load' method
  - Fix type of `raw_entries' arg to `apply' method
    to list[Entry] from list[str]
  - Rename `raw_entries' arg to `apply' method to `entries'
  - Fix `raw_query' arg used in `apply' method of subclasses to `query'
  - Set type of entries, corpus_embeddings in TextSearchModel

- Verification
  Ran `mypy --config-file .mypy.ini src' to verify typing
This commit is contained in:
Debanjum Singh Solanky 2023-01-09 16:53:18 -03:00
parent d40076fcd6
commit 8498903641
5 changed files with 27 additions and 21 deletions

View file

@ -1,13 +1,16 @@
# Standard Packages
from abc import ABC, abstractmethod
# Internal Packages
from src.utils.rawconfig import Entry
class BaseFilter(ABC):
@abstractmethod
def load(self, *args, **kwargs): ...
def load(self, entries: list[Entry], *args, **kwargs): ...
@abstractmethod
def can_filter(self, raw_query:str) -> bool: ...
@abstractmethod
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: ...
def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...

View file

@ -33,7 +33,7 @@ class DateFilter(BaseFilter):
self.cache = LRU()
def load(self, entries, **_):
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
# Extract dates from entry
@ -53,7 +53,7 @@ class DateFilter(BaseFilter):
return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries):
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
start = time.time()
@ -63,7 +63,7 @@ class DateFilter(BaseFilter):
# if no date in query, return all entries
if query_daterange is None:
return query, set(range(len(raw_entries)))
return query, set(range(len(entries)))
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
@ -77,7 +77,7 @@ class DateFilter(BaseFilter):
return query, entries_to_include
if not self.date_to_entry_ids:
self.load(raw_entries)
self.load(entries)
# find entries containing any dates that fall with date range specified in query
start = time.time()

View file

@ -31,12 +31,12 @@ class FileFilter(BaseFilter):
def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, raw_query, raw_entries):
def apply(self, query, entries):
# Extract file filters from raw query
start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search:
return raw_query, set(range(len(raw_entries)))
return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org"
@ -50,7 +50,7 @@ class FileFilter(BaseFilter):
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', raw_query).strip()
query = re.sub(self.file_filter_regex, '', query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.info(f"Return file filter results from cache")
@ -58,7 +58,7 @@ class FileFilter(BaseFilter):
return query, included_entry_indices
if not self.file_to_entry_map:
self.load(raw_entries, regenerate=False)
self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion
start = time.time()

View file

@ -23,7 +23,7 @@ class WordFilter(BaseFilter):
self.cache = LRU()
def load(self, entries, regenerate=False):
def load(self, entries, *args, **kwargs):
start = time.time()
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
@ -47,20 +47,20 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, raw_entries):
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(raw_entries)))
return query, set(range(len(entries)))
# Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
@ -70,12 +70,12 @@ class WordFilter(BaseFilter):
return query, included_entry_indices
if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False)
self.load(entries, regenerate=False)
start = time.time()
# mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(raw_entries)))
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

View file

@ -3,8 +3,11 @@ from enum import Enum
from dataclasses import dataclass
from pathlib import Path
# External Packages
import torch
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig
from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter
@ -21,7 +24,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder