mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Fix, add typing to Filter and TextSearchModel classes
- Changes - Fix method signatures of BaseFilter subclasses. Else typing information isn't translating to them - Explicitly pass `entries: list[Entry]' as arg to `load' method - Fix type of `raw_entries' arg to `apply' method to list[Entry] from list[str] - Rename `raw_entries' arg to `apply' method to `entries' - Fix `raw_query' arg used in `apply' method of subclasses to `query' - Set type of entries, corpus_embeddings in TextSearchModel - Verification Ran `mypy --config-file .mypy.ini src' to verify typing
This commit is contained in:
parent
d40076fcd6
commit
8498903641
5 changed files with 27 additions and 21 deletions
|
@ -1,13 +1,16 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import Entry
|
||||
|
||||
|
||||
class BaseFilter(ABC):
|
||||
@abstractmethod
|
||||
def load(self, *args, **kwargs): ...
|
||||
def load(self, entries: list[Entry], *args, **kwargs): ...
|
||||
|
||||
@abstractmethod
|
||||
def can_filter(self, raw_query:str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: ...
|
||||
def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...
|
||||
|
|
|
@ -33,7 +33,7 @@ class DateFilter(BaseFilter):
|
|||
self.cache = LRU()
|
||||
|
||||
|
||||
def load(self, entries, **_):
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
|
@ -53,7 +53,7 @@ class DateFilter(BaseFilter):
|
|||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
|
||||
def apply(self, query, raw_entries):
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
start = time.time()
|
||||
|
@ -63,7 +63,7 @@ class DateFilter(BaseFilter):
|
|||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
return query, set(range(len(raw_entries)))
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
||||
|
@ -77,7 +77,7 @@ class DateFilter(BaseFilter):
|
|||
return query, entries_to_include
|
||||
|
||||
if not self.date_to_entry_ids:
|
||||
self.load(raw_entries)
|
||||
self.load(entries)
|
||||
|
||||
# find entries containing any dates that fall with date range specified in query
|
||||
start = time.time()
|
||||
|
|
|
@ -31,12 +31,12 @@ class FileFilter(BaseFilter):
|
|||
def can_filter(self, raw_query):
|
||||
return re.search(self.file_filter_regex, raw_query) is not None
|
||||
|
||||
def apply(self, raw_query, raw_entries):
|
||||
def apply(self, query, entries):
|
||||
# Extract file filters from raw query
|
||||
start = time.time()
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||
if not raw_files_to_search:
|
||||
return raw_query, set(range(len(raw_entries)))
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# Convert simple file filters with no path separator into regex
|
||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||
|
@ -50,7 +50,7 @@ class FileFilter(BaseFilter):
|
|||
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
||||
|
||||
# Return item from cache if exists
|
||||
query = re.sub(self.file_filter_regex, '', raw_query).strip()
|
||||
query = re.sub(self.file_filter_regex, '', query).strip()
|
||||
cache_key = tuple(files_to_search)
|
||||
if cache_key in self.cache:
|
||||
logger.info(f"Return file filter results from cache")
|
||||
|
@ -58,7 +58,7 @@ class FileFilter(BaseFilter):
|
|||
return query, included_entry_indices
|
||||
|
||||
if not self.file_to_entry_map:
|
||||
self.load(raw_entries, regenerate=False)
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
# Mark entries that contain any blocked_words for exclusion
|
||||
start = time.time()
|
||||
|
|
|
@ -23,7 +23,7 @@ class WordFilter(BaseFilter):
|
|||
self.cache = LRU()
|
||||
|
||||
|
||||
def load(self, entries, regenerate=False):
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
self.cache = {} # Clear cache on filter (re-)load
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||
|
@ -47,20 +47,20 @@ class WordFilter(BaseFilter):
|
|||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
|
||||
def apply(self, raw_query, raw_entries):
|
||||
def apply(self, query, entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
start = time.time()
|
||||
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
|
||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(raw_entries)))
|
||||
return query, set(range(len(entries)))
|
||||
|
||||
# Return item from cache if exists
|
||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
||||
|
@ -70,12 +70,12 @@ class WordFilter(BaseFilter):
|
|||
return query, included_entry_indices
|
||||
|
||||
if not self.word_to_entry_index:
|
||||
self.load(raw_entries, regenerate=False)
|
||||
self.load(entries, regenerate=False)
|
||||
|
||||
start = time.time()
|
||||
|
||||
# mark entries that contain all required_words for inclusion
|
||||
entries_with_all_required_words = set(range(len(raw_entries)))
|
||||
entries_with_all_required_words = set(range(len(entries)))
|
||||
if len(required_words) > 0:
|
||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||
|
||||
|
|
|
@ -3,8 +3,11 @@ from enum import Enum
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import ConversationProcessorConfig
|
||||
from src.utils.rawconfig import ConversationProcessorConfig, Entry
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
|
||||
|
@ -21,7 +24,7 @@ class ProcessorType(str, Enum):
|
|||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
|
|
Loading…
Reference in a new issue