Use List, Tuple, Set from typing to support Python 3.8 for khoj

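Before Python 3.9, the builtin collection types (list, tuple, set, etc.)
cannot be subscripted in type hints (e.g. list[str] raises a TypeError when
evaluated), so use the equivalent generics from typing (List, Tuple, Set) instead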

Resolves #130
This commit is contained in:
Debanjum Singh Solanky 2023-02-06 01:08:43 -03:00
parent 14f28e3a03
commit cba9a6a703
12 changed files with 42 additions and 37 deletions

View file

@ -2,7 +2,7 @@
import glob
import re
import logging
import time
from typing import List
# Internal Packages
from src.processor.text_to_jsonl import TextToJsonl
@ -109,7 +109,7 @@ class BeancountToJsonl(TextToJsonl):
return entries, dict(transaction_to_file_map)
@staticmethod
def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
def convert_transactions_to_maps(parsed_entries: List[str], transaction_to_file_map) -> List[Entry]:
"Convert each parsed Beancount transaction into a Entry"
entries = []
for parsed_entry in parsed_entries:
@ -120,6 +120,6 @@ class BeancountToJsonl(TextToJsonl):
return entries
@staticmethod
def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
"Convert each Beancount transaction entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])

View file

@ -3,6 +3,7 @@ import glob
import re
import logging
import time
from typing import List
# Internal Packages
from src.processor.text_to_jsonl import TextToJsonl
@ -110,7 +111,7 @@ class MarkdownToJsonl(TextToJsonl):
return entries, dict(entry_to_file_map)
@staticmethod
def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each Markdown entries into a dictionary"
entries = []
for parsed_entry in parsed_entries:
@ -121,6 +122,6 @@ class MarkdownToJsonl(TextToJsonl):
return entries
@staticmethod
def convert_markdown_maps_to_jsonl(entries: list[Entry]):
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
"Convert each Markdown entry to JSON and collate as JSONL"
return ''.join([f'{entry.to_json()}\n' for entry in entries])

View file

@ -2,7 +2,7 @@
import glob
import logging
import time
from typing import Iterable
from typing import Iterable, List
# Internal Packages
from src.processor.org_mode import orgnode
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl):
# Define Functions
def process(self, previous_entries: list[Entry]=None):
def process(self, previous_entries: List[Entry]=None):
# Extract required fields from config
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
index_heading_entries = self.config.index_heading_entries
@ -101,9 +101,9 @@ class OrgToJsonl(TextToJsonl):
return entries, dict(entry_to_file_map)
@staticmethod
def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
"Convert Org-Mode nodes into list of Entry objects"
entries: list[Entry] = []
entries: List[Entry] = []
for parsed_entry in parsed_entries:
if not parsed_entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body

View file

@ -37,6 +37,7 @@ import re
import datetime
from pathlib import Path
from os.path import relpath
from typing import List
indent_regex = re.compile(r'^ *')
@ -69,7 +70,7 @@ def makelist(filename):
sched_date = ''
deadline_date = ''
logbook = list()
nodelist: list[Orgnode] = list()
nodelist: List[Orgnode] = list()
property_map = dict()
in_properties_drawer = False
in_logbook_drawer = False

View file

@ -1,9 +1,8 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
import time
import logging
from typing import Callable
from typing import Callable, List, Tuple
from src.utils.helpers import timer
# Internal Packages
@ -18,16 +17,16 @@ class TextToJsonl(ABC):
self.config = config
@abstractmethod
def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ...
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ...
@staticmethod
def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
@staticmethod
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256, max_word_length: int=500) -> list[Entry]:
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]:
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
chunked_entries: list[Entry] = []
chunked_entries: List[Entry] = []
for entry in entries:
compiled_entry_words = entry.compiled.split()
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
@ -39,7 +38,7 @@ class TextToJsonl(ABC):
chunked_entries.append(entry_chunk)
return chunked_entries
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))

View file

@ -1,7 +1,7 @@
# Standard Packages
import yaml
import logging
from typing import Optional
from typing import List, Optional
# External Packages
from fastapi import APIRouter
@ -38,9 +38,9 @@ async def set_config_data(updated_config: FullConfig):
outfile.close()
return state.config
@api.get('/search', response_model=list[SearchResponse])
@api.get('/search', response_model=List[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
results: list[SearchResponse] = []
results: List[SearchResponse] = []
if q is None or q == '':
logger.info(f'No query param (q) passed in API call to initiate search')
return results

View file

@ -1,5 +1,6 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
# Internal Packages
from src.utils.rawconfig import Entry
@ -7,10 +8,10 @@ from src.utils.rawconfig import Entry
class BaseFilter(ABC):
@abstractmethod
def load(self, entries: list[Entry], *args, **kwargs): ...
def load(self, entries: List[Entry], *args, **kwargs): ...
@abstractmethod
def can_filter(self, raw_query:str) -> bool: ...
@abstractmethod
def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ...

View file

@ -5,6 +5,7 @@ import copy
import shutil
import time
import logging
from typing import List
# External Packages
from sentence_transformers import SentenceTransformer, util
@ -189,8 +190,8 @@ def query(raw_query, count, model: ImageSearchModel):
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
results: list[SearchResponse] = []
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]:
results: List[SearchResponse] = []
for index, hit in enumerate(hits[:count]):
source_path = image_names[hit['corpus_id']]

View file

@ -2,7 +2,7 @@
import logging
from pathlib import Path
import time
from typing import Type
from typing import List, Tuple, Type
# External Packages
import torch
@ -53,12 +53,12 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file) -> list[Entry]:
def extract_entries(jsonl_file) -> List[Entry]:
"Load entries from compressed jsonl"
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
# Load pre-computed embeddings from file if exists and update them if required
@ -90,7 +90,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> tuple[list[dict], list[Entry]]:
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
@ -127,7 +127,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
return hits, entries
def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
return [SearchResponse.parse_obj(
{
"entry": entries[hit['corpus_id']].raw,
@ -141,7 +141,7 @@ def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]
in hits[0:count]]
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@ -166,7 +166,7 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
with timer("Total Filter Time", logger, state.device):
@ -186,7 +186,7 @@ def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Ten
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Score all retrieved entries using the cross-encoder'''
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
@ -199,7 +199,7 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
return hits
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
'''Order results by cross-encoder score followed by bi-encoder score'''
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
@ -208,7 +208,7 @@ def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
return hits
def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
'''Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.'''

View file

@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
# External Packages
import torch
@ -29,7 +29,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: list[BaseFilter], top_k):
def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder

View file

@ -1,5 +1,6 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List
# External Packages
import openai
@ -15,7 +16,7 @@ class BaseEncoder(ABC):
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
@abstractmethod
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
class OpenAI(BaseEncoder):

View file

@ -1,5 +1,6 @@
# Standard Packages
import threading
from typing import List
from packaging import version
# External Packages
@ -19,7 +20,7 @@ config_file: Path = None
verbose: int = 0
host: str = None
port: int = None
cli_args: list[str] = None
cli_args: List[str] = None
query_cache = LRU()
search_index_lock = threading.Lock()