From 8498903641e3ec8866b1e693ceb06f6a2c8437a1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 9 Jan 2023 16:53:18 -0300 Subject: [PATCH] Fix, add typing to Filter and TextSearchModel classes - Changes - Fix method signatures of BaseFilter subclasses. Otherwise typing information does not carry over to them - Explicitly pass `entries: list[Entry]' as arg to `load' method - Fix type of `raw_entries' arg to `apply' method to list[Entry] from list[str] - Rename `raw_entries' arg to `apply' method to `entries' - Fix `raw_query' arg used in `apply' method of subclasses to `query' - Set type of entries, corpus_embeddings in TextSearchModel - Verification Ran `mypy --config-file .mypy.ini src' to verify typing --- src/search_filter/base_filter.py | 7 +++++-- src/search_filter/date_filter.py | 8 ++++---- src/search_filter/file_filter.py | 10 +++++----- src/search_filter/word_filter.py | 16 ++++++++-------- src/utils/config.py | 7 +++++-- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py index 5b0ed62c..a1b56492 100644 --- a/src/search_filter/base_filter.py +++ b/src/search_filter/base_filter.py @@ -1,13 +1,16 @@ # Standard Packages from abc import ABC, abstractmethod +# Internal Packages +from src.utils.rawconfig import Entry + class BaseFilter(ABC): @abstractmethod - def load(self, *args, **kwargs): ... + def load(self, entries: list[Entry], *args, **kwargs): ... @abstractmethod def can_filter(self, raw_query:str) -> bool: ... @abstractmethod - def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: ... + def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ... 
diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 00b829ac..4f1242c9 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -33,7 +33,7 @@ class DateFilter(BaseFilter): self.cache = LRU() - def load(self, entries, **_): + def load(self, entries, *args, **kwargs): start = time.time() for id, entry in enumerate(entries): # Extract dates from entry @@ -53,7 +53,7 @@ class DateFilter(BaseFilter): return self.extract_date_range(raw_query) is not None - def apply(self, query, raw_entries): + def apply(self, query, entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query start = time.time() @@ -63,7 +63,7 @@ class DateFilter(BaseFilter): # if no date in query, return all entries if query_daterange is None: - return query, set(range(len(raw_entries))) + return query, set(range(len(entries))) # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) @@ -77,7 +77,7 @@ class DateFilter(BaseFilter): return query, entries_to_include if not self.date_to_entry_ids: - self.load(raw_entries) + self.load(entries) # find entries containing any dates that fall with date range specified in query start = time.time() diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 84b520c0..95635207 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -31,12 +31,12 @@ class FileFilter(BaseFilter): def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None - def apply(self, raw_query, raw_entries): + def apply(self, query, entries): # Extract file filters from raw query start = time.time() - raw_files_to_search = re.findall(self.file_filter_regex, raw_query) + raw_files_to_search = re.findall(self.file_filter_regex, query) if not raw_files_to_search: - return raw_query, set(range(len(raw_entries))) + 
return query, set(range(len(entries))) # Convert simple file filters with no path separator into regex # e.g. "file:notes.org" -> "file:.*notes.org" @@ -50,7 +50,7 @@ class FileFilter(BaseFilter): logger.debug(f"Extract files_to_search from query: {end - start} seconds") # Return item from cache if exists - query = re.sub(self.file_filter_regex, '', raw_query).strip() + query = re.sub(self.file_filter_regex, '', query).strip() cache_key = tuple(files_to_search) if cache_key in self.cache: logger.info(f"Return file filter results from cache") @@ -58,7 +58,7 @@ class FileFilter(BaseFilter): return query, included_entry_indices if not self.file_to_entry_map: - self.load(raw_entries, regenerate=False) + self.load(entries, regenerate=False) # Mark entries that contain any blocked_words for exclusion start = time.time() diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index ff9f9ee5..e1cfea6a 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -23,7 +23,7 @@ class WordFilter(BaseFilter): self.cache = LRU() - def load(self, entries, regenerate=False): + def load(self, entries, *args, **kwargs): start = time.time() self.cache = {} # Clear cache on filter (re-)load entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' @@ -47,20 +47,20 @@ class WordFilter(BaseFilter): return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, raw_query, raw_entries): + def apply(self, query, entries): "Find entries containing required and not blocked words specified in query" # Separate natural query from required, blocked words filters start = time.time() - required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) - blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) - query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip() + required_words 
= set([word.lower() for word in re.findall(self.required_regex, query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip() end = time.time() logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: - return query, set(range(len(raw_entries))) + return query, set(range(len(entries))) # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) @@ -70,12 +70,12 @@ class WordFilter(BaseFilter): return query, included_entry_indices if not self.word_to_entry_index: - self.load(raw_entries, regenerate=False) + self.load(entries, regenerate=False) start = time.time() # mark entries that contain all required_words for inclusion - entries_with_all_required_words = set(range(len(raw_entries))) + entries_with_all_required_words = set(range(len(entries))) if len(required_words) > 0: entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) diff --git a/src/utils/config.py b/src/utils/config.py index bfbf65d9..8ca9e526 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -3,8 +3,11 @@ from enum import Enum from dataclasses import dataclass from pathlib import Path +# External Packages +import torch + # Internal Packages -from src.utils.rawconfig import ConversationProcessorConfig +from src.utils.rawconfig import ConversationProcessorConfig, Entry from src.search_filter.base_filter import BaseFilter @@ -21,7 +24,7 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k): + def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k): self.entries = entries 
self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder