diff --git a/config/environment.yml b/config/environment.yml index 289d20d7..bf3bb82a 100644 --- a/config/environment.yml +++ b/config/environment.yml @@ -18,3 +18,4 @@ dependencies: - jinja2=3.1.2 - aiofiles=0.8.0 - huggingface_hub=0.8.1 + - dateparser=1.1.1 \ No newline at end of file diff --git a/src/main.py b/src/main.py index 00e68612..addd55d5 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con from src.utils.rawconfig import FullConfig from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.search_filter.explicit_filter import explicit_filter +from src.search_filter.date_filter import date_filter # Application Global State config = FullConfig() @@ -59,14 +60,14 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter]) + hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter]) # collate and return results return asymmetric.collate_results(hits, entries, results_count) if (t == SearchType.Music or t == None) and model.music_search: # query music library - hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter]) + hits, entries = asymmetric.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter]) # collate and return results return asymmetric.collate_results(hits, entries, results_count) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py new file mode 100644 index 00000000..81febe1e --- /dev/null +++ b/src/search_filter/date_filter.py @@ -0,0 +1,146 @@ +# Standard Packages +import re +from datetime import 
# Standard Packages
import re
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from math import inf

# External Packages
import torch


# Date Range Filter Regexes
# Example filter queries:
#  - dt>="yesterday" dt<"tomorrow"
#  - dt>="last week"
#  - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""

# Dates written as YYYY-MM-DD inside the text of an entry
_entry_date_regex = re.compile(r'\d{4}-\d{2}-\d{2}')


def date_filter(query, entries, embeddings):
    """Keep only the entries containing a date inside the date range given in query.

    Args:
        query: search query, optionally carrying date filters such as
            `dt>="yesterday"` (see `date_regex`).
        entries: sequence of entries; the text searched for YYYY-MM-DD dates
            is read from index 1 of each entry.
        embeddings: torch tensor indexed along its first dimension in step
            with `entries`.

    Returns:
        Tuple `(query, entries, embeddings)`. When `extract_date_range` finds
        no usable date range, all three inputs are returned unchanged.
        Otherwise the date filters are stripped from the query, and a new list
        and a new tensor holding only the matching entries are returned. The
        caller's `entries` list is never modified.
    """
    # extract date range specified in date filter of query
    query_daterange = extract_date_range(query)

    # if no date in query, return all entries
    if query_daterange is None:
        return query, entries, embeddings
    range_start, range_end = query_daterange

    # remove date range filter from query
    # `\s*` (not `\s+`) so a filter at the very start of the query is removed too
    query = re.sub(rf'\s*{date_regex}', ' ', query)
    query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces

    # find entries containing any date that falls within the date range of the query
    indices_to_include = [
        index for index, entry in enumerate(entries)
        if _has_date_in_range(entry[1], range_start, range_end)]

    # select the surviving entries and their embeddings in a single pass
    # instead of deleting (and re-concatenating the tensor) once per excluded entry
    filtered_entries = [entries[index] for index in indices_to_include]
    selector = torch.tensor(indices_to_include, dtype=torch.long, device=embeddings.device)

    return query, filtered_entries, embeddings[selector]


def _has_date_in_range(text, range_start, range_end):
    "Check if text has a YYYY-MM-DD date whose timestamp lies in [range_start, range_end)"
    for date_string in _entry_date_regex.findall(text):
        # Convert date string in entry to unix timestamp
        try:
            timestamp = datetime.strptime(date_string, '%Y-%m-%d').timestamp()
        except (ValueError, OverflowError, OSError):
            # looks like a date but is not a valid or representable one, e.g. 2022-13-45
            continue
        if range_start <= timestamp < range_end:
            return True
    return False


def extract_date_range(query):
    """Combine all date filters in query into one effective timestamp range.

    Each filter is parsed into a natural date range (e.g. today maps to
    [start_of_day, start_of_tomorrow)), widened or clipped according to its
    comparator, and the resulting intervals are intersected.

    Returns:
        `[start, end]` as unix timestamps (`end` may be `inf`), or None when
        the query has no date filter, none of the filters can be parsed, the
        filters leave the range at its default `[0, inf]`, or they do not
        intersect (start > end).

    NOTE: open-ended ranges start at 0 (the unix epoch), so a filter such as
    dt<"1984-01-01" does not match entry dates with a negative timestamp.
    """
    # find date range filter in query
    date_range_matches = re.findall(date_regex, query)
    if not date_range_matches:
        return None

    # Combine dates with their comparators to form date range intervals
    # For e.g
    #   >=yesterday maps to [start_of_yesterday, inf)
    #   <tomorrow maps to [0, start_of_tomorrow)
    # and intersect (AND) those intervals, which in the example above
    # gives the effective range [start_of_yesterday, start_of_tomorrow)
    effective_date_range = [0, inf]
    for cmp, date_str in date_range_matches:
        parsed_range = parse(date_str)  # parse once; date parsing is the costly step
        if not parsed_range:
            continue
        dt_start, dt_end = parsed_range
        interval = _interval_for_comparator(cmp, dt_start.timestamp(), dt_end.timestamp())
        if interval is None:
            continue
        effective_date_range = [
            max(effective_date_range[0], interval[0]),
            min(effective_date_range[1], interval[1])]

    if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
        return None
    return effective_date_range


def _interval_for_comparator(cmp, dtrange_start, dtrange_end):
    "Map a comparator and a date range to an interval; None for unsupported comparators (e.g. '><')"
    if cmp == '>':
        return [dtrange_end, inf]
    if cmp == '>=':
        return [dtrange_start, inf]
    if cmp == '<':
        return [0, dtrange_start]
    if cmp == '<=':
        return [0, dtrange_end]
    if cmp in ('=', ':', '=='):
        return [dtrange_start, dtrange_end]
    return None


def parse(date_str, relative_base=None):
    """Parse date string passed in date filter of query to a date range.

    Args:
        date_str: natural language or YYYY-MM-DD date, e.g. "2 weeks ago".
        relative_base: datetime that relative dates are resolved against.
            Defaults to the current time.

    Returns:
        `(start, end)` datetimes as built by `date_to_daterange`, or None if
        the date string cannot be parsed.
    """
    # clean date string to handle future date parsing by date parser
    future_strings = ['later', 'from now', 'from today']
    prefer_dates_from = 'future' if any(fstr in date_str for fstr in future_strings) else 'past'
    clean_date_str = re.sub('|'.join(future_strings), '', date_str)

    # plain YYYY-MM-DD dates need no natural language parsing
    parsed_date = _parse_iso_date(clean_date_str)

    if parsed_date is None:
        # imported here so the natural language parser is only loaded
        # when a query carries a date filter that actually needs it
        import dateparser as dtparse

        # parse date passed in query date filter
        parsed_date = dtparse.parse(
            clean_date_str,
            settings={
                'RELATIVE_BASE': relative_base or datetime.now(),
                'PREFER_DAY_OF_MONTH': 'first',
                'PREFER_DATES_FROM': prefer_dates_from
            })

    if parsed_date is None:
        return None

    return date_to_daterange(parsed_date, date_str)


def _parse_iso_date(date_str):
    "Parse a bare YYYY-MM-DD date string to datetime; None if it is anything else"
    try:
        return datetime.strptime(date_str.strip(), '%Y-%m-%d')
    except ValueError:
        return None


def date_to_daterange(parsed_date, date_str):
    """Convert parsed date to a date range at natural granularity (day, week, month or year).

    The granularity is chosen by the words present in the original date
    string: 'year', then 'month', then 'week'; anything else maps to a day.

    Returns:
        `(start, end)` datetimes, where `end` is the start of the next period.
    """
    if 'year' in date_str:
        return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))

    if 'month' in date_str:
        start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
        next_month = start_of_month + relativedelta(months=1)
        return (start_of_month, next_month)

    start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)

    if 'week' in date_str:
        # the parsed date is treated as the end of the week asked for,
        # so the week is the 7 days leading up to the start of that day
        start_of_week = start_of_day - timedelta(days=7)
        return (start_of_week, start_of_day)

    next_day = start_of_day + relativedelta(days=1)
    return (start_of_day, next_day)
# Standard Packages
import re
from datetime import datetime
from math import inf

# External Packages
import torch

# Application Packages
from src.search_filter import date_filter


def test_date_filter():
    "Entries are kept only when one of their dates lies within the range given in the query"
    embeddings = torch.randn(3, 10)
    entries = [
        ['', 'Entry with no date'],
        ['', 'April Fools entry: 1984-04-01'],
        ['', 'Entry with date:1984-04-02']]

    # a query without any date filter passes through untouched
    q_with_no_date_filter = 'head tail'
    ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)
    assert ret_query == 'head tail'
    assert len(ret_emb) == 3
    assert ret_entries == entries

    # (query, indices of the entries expected to survive the filter)
    queries_with_date_filter = [
        # date range not overlapping any entry at its boundary
        ('head dt>"1984-04-01" dt<"1984-04-02" tail', []),
        # date ranges overlapping one or both of the dated entries
        ('head dt>"1984-04-01" dt<"1984-04-03" tail', [2]),
        ('head dt>="1984-04-01" dt<"1984-04-02" tail', [1]),
        ('head dt>"1984-04-01" dt<="1984-04-02" tail', [2]),
        ('head dt>="1984-04-01" dt<="1984-04-02" tail', [1, 2]),
    ]
    for query, expected_indices in queries_with_date_filter:
        ret_query, ret_entries, ret_emb = date_filter.date_filter(query, entries.copy(), embeddings)
        assert ret_query == 'head tail'
        assert ret_entries == [entries[index] for index in expected_indices]
        assert len(ret_emb) == len(expected_indices)


def test_extract_date_range():
    "Date filters in a query are combined into a single [start, end] timestamp range"
    def timestamp(year, month, day):
        return datetime(year, month, day, 0, 0, 0).timestamp()

    queries_with_date_range = [
        ('head dt>"1984-01-04" dt<"1984-01-07" tail', [timestamp(1984, 1, 5), timestamp(1984, 1, 7)]),
        ('head dt<="1984-01-01"', [0, timestamp(1984, 1, 2)]),
        ('head dt>="1984-01-01"', [timestamp(1984, 1, 1), inf]),
        ('head dt:"1984-01-01"', [timestamp(1984, 1, 1), timestamp(1984, 1, 2)]),
    ]
    for query, expected_range in queries_with_date_range:
        assert date_filter.extract_date_range(query) == expected_range

    queries_without_date_range = [
        'head dt:"Summer of 69" tail',                 # unparseable date filter specified in query
        'head tail',                                   # no date filter specified in query
        'head dt>"1984-01-01" dt<"1984-01-01" tail',   # non intersecting date ranges
    ]
    for query in queries_without_date_range:
        assert date_filter.extract_date_range(query) is None


def test_parse():
    "Natural language dates map to a range at day, week, month or year granularity"
    test_now = datetime(1984, 4, 1, 21, 21, 21)

    # (date string, expected start of range, expected end of range)
    expected_ranges = [
        # day variations
        ('today', datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)),
        ('tomorrow', datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)),
        ('yesterday', datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)),
        ('5 days ago', datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)),
        # week variations
        ('last week', datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)),
        ('2 weeks ago', datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)),
        # month variations
        ('next month', datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)),
        ('2 months ago', datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)),
        # year variations
        ('this year', datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)),
        ('20 years later', datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)),
        # specific month/date variation
        ('in august', datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)),
        ('on 1983-08-01', datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)),
    ]
    for date_str, range_start, range_end in expected_ranges:
        assert date_filter.parse(date_str, relative_base=test_now) == (range_start, range_end)


def test_date_filter_regex():
    "The date filter regex captures the (comparator, date string) pair of every filter in a query"
    expected_matches = [
        ('multi word head dt>"today" dt:"1984-01-01"', [('>', 'today'), (':', '1984-01-01')]),
        ('head dt>"today" dt:"1984-01-01" multi word tail', [('>', 'today'), (':', '1984-01-01')]),
        ('multi word head dt>="today" dt="1984-01-01"', [('>=', 'today'), ('=', '1984-01-01')]),
        ('dt<"multi word date" multi word tail', [('<', 'multi word date')]),
        ('head dt<="multi word date"', [('<=', 'multi word date')]),
        ('head tail', []),
    ]
    for query, expected_match in expected_matches:
        dtrange_match = re.findall(date_filter.date_regex, query)
        assert dtrange_match == expected_match