mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Use List, Tuple, Set from typing to support Python 3.8 for khoj
Before Python 3.9, you can't directly use list, tuple, set, etc. for type hinting. Resolves #130
This commit is contained in:
parent
14f28e3a03
commit
cba9a6a703
12 changed files with 42 additions and 37 deletions
|
@ -2,7 +2,7 @@
|
|||
import glob
|
||||
import re
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.text_to_jsonl import TextToJsonl
|
||||
|
@ -109,7 +109,7 @@ class BeancountToJsonl(TextToJsonl):
|
|||
return entries, dict(transaction_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
|
||||
def convert_transactions_to_maps(parsed_entries: List[str], transaction_to_file_map) -> List[Entry]:
|
||||
"Convert each parsed Beancount transaction into a Entry"
|
||||
entries = []
|
||||
for parsed_entry in parsed_entries:
|
||||
|
@ -120,6 +120,6 @@ class BeancountToJsonl(TextToJsonl):
|
|||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
|
||||
def convert_transaction_maps_to_jsonl(entries: List[Entry]) -> str:
|
||||
"Convert each Beancount transaction entry to JSON and collate as JSONL"
|
||||
return ''.join([f'{entry.to_json()}\n' for entry in entries])
|
||||
|
|
|
@ -3,6 +3,7 @@ import glob
|
|||
import re
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.text_to_jsonl import TextToJsonl
|
||||
|
@ -110,7 +111,7 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
return entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
|
||||
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
"Convert each Markdown entries into a dictionary"
|
||||
entries = []
|
||||
for parsed_entry in parsed_entries:
|
||||
|
@ -121,6 +122,6 @@ class MarkdownToJsonl(TextToJsonl):
|
|||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_maps_to_jsonl(entries: list[Entry]):
|
||||
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
|
||||
"Convert each Markdown entry to JSON and collate as JSONL"
|
||||
return ''.join([f'{entry.to_json()}\n' for entry in entries])
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import glob
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterable
|
||||
from typing import Iterable, List
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
|
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class OrgToJsonl(TextToJsonl):
|
||||
# Define Functions
|
||||
def process(self, previous_entries: list[Entry]=None):
|
||||
def process(self, previous_entries: List[Entry]=None):
|
||||
# Extract required fields from config
|
||||
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
|
||||
index_heading_entries = self.config.index_heading_entries
|
||||
|
@ -101,9 +101,9 @@ class OrgToJsonl(TextToJsonl):
|
|||
return entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
|
||||
def convert_org_nodes_to_entries(parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> List[Entry]:
|
||||
"Convert Org-Mode nodes into list of Entry objects"
|
||||
entries: list[Entry] = []
|
||||
entries: List[Entry] = []
|
||||
for parsed_entry in parsed_entries:
|
||||
if not parsed_entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
|
|
|
@ -37,6 +37,7 @@ import re
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
from os.path import relpath
|
||||
from typing import List
|
||||
|
||||
indent_regex = re.compile(r'^ *')
|
||||
|
||||
|
@ -69,7 +70,7 @@ def makelist(filename):
|
|||
sched_date = ''
|
||||
deadline_date = ''
|
||||
logbook = list()
|
||||
nodelist: list[Orgnode] = list()
|
||||
nodelist: List[Orgnode] = list()
|
||||
property_map = dict()
|
||||
in_properties_drawer = False
|
||||
in_logbook_drawer = False
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
import time
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Callable, List, Tuple
|
||||
from src.utils.helpers import timer
|
||||
|
||||
# Internal Packages
|
||||
|
@ -18,16 +17,16 @@ class TextToJsonl(ABC):
|
|||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ...
|
||||
def process(self, previous_entries: List[Entry]=None) -> List[Tuple[int, Entry]]: ...
|
||||
|
||||
@staticmethod
|
||||
def hash_func(key: str) -> Callable:
|
||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256, max_word_length: int=500) -> list[Entry]:
|
||||
def split_entries_by_max_tokens(entries: List[Entry], max_tokens: int=256, max_word_length: int=500) -> List[Entry]:
|
||||
"Split entries if compiled entry length exceeds the max tokens supported by the ML model."
|
||||
chunked_entries: list[Entry] = []
|
||||
chunked_entries: List[Entry] = []
|
||||
for entry in entries:
|
||||
compiled_entry_words = entry.compiled.split()
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
|
@ -39,7 +38,7 @@ class TextToJsonl(ABC):
|
|||
chunked_entries.append(entry_chunk)
|
||||
return chunked_entries
|
||||
|
||||
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
|
||||
def mark_entries_for_update(self, current_entries: List[Entry], previous_entries: List[Entry], key='compiled', logger=None) -> List[Tuple[int, Entry]]:
|
||||
# Hash all current and previous entries to identify new entries
|
||||
with timer("Hash previous, current entries", logger):
|
||||
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Standard Packages
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
|
@ -38,9 +38,9 @@ async def set_config_data(updated_config: FullConfig):
|
|||
outfile.close()
|
||||
return state.config
|
||||
|
||||
@api.get('/search', response_model=list[SearchResponse])
|
||||
@api.get('/search', response_model=List[SearchResponse])
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||
results: list[SearchResponse] = []
|
||||
results: List[SearchResponse] = []
|
||||
if q is None or q == '':
|
||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
||||
return results
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import Entry
|
||||
|
@ -7,10 +8,10 @@ from src.utils.rawconfig import Entry
|
|||
|
||||
class BaseFilter(ABC):
|
||||
@abstractmethod
|
||||
def load(self, entries: list[Entry], *args, **kwargs): ...
|
||||
def load(self, entries: List[Entry], *args, **kwargs): ...
|
||||
|
||||
@abstractmethod
|
||||
def can_filter(self, raw_query:str) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...
|
||||
def apply(self, query:str, entries: List[Entry]) -> Tuple[str, Set[int]]: ...
|
||||
|
|
|
@ -5,6 +5,7 @@ import copy
|
|||
import shutil
|
||||
import time
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
# External Packages
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
|
@ -189,8 +190,8 @@ def query(raw_query, count, model: ImageSearchModel):
|
|||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||
|
||||
|
||||
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> list[SearchResponse]:
|
||||
results: list[SearchResponse] = []
|
||||
def collate_results(hits, image_names, output_directory, image_files_url, count=5) -> List[SearchResponse]:
|
||||
results: List[SearchResponse] = []
|
||||
|
||||
for index, hit in enumerate(hits[:count]):
|
||||
source_path = image_names[hit['corpus_id']]
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Type
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -53,12 +53,12 @@ def initialize_model(search_config: TextSearchConfig):
|
|||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
def extract_entries(jsonl_file) -> list[Entry]:
|
||||
def extract_entries(jsonl_file) -> List[Entry]:
|
||||
"Load entries from compressed jsonl"
|
||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||
|
||||
|
||||
def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
|
||||
def compute_embeddings(entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
|
@ -90,7 +90,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: Ba
|
|||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> tuple[list[dict], list[Entry]]:
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||
|
||||
|
@ -127,7 +127,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
|||
return hits, entries
|
||||
|
||||
|
||||
def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
|
||||
def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]:
|
||||
return [SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": entries[hit['corpus_id']].raw,
|
||||
|
@ -141,7 +141,7 @@ def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]
|
|||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
|
||||
def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: List[BaseFilter] = []) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
|
@ -166,7 +166,7 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
|||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
|
||||
|
||||
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
||||
def apply_filters(query: str, entries: List[Entry], corpus_embeddings: torch.Tensor, filters: List[BaseFilter]) -> Tuple[str, List[Entry], torch.Tensor]:
|
||||
'''Filter query, entries and embeddings before semantic search'''
|
||||
|
||||
with timer("Total Filter Time", logger, state.device):
|
||||
|
@ -186,7 +186,7 @@ def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Ten
|
|||
return query, entries, corpus_embeddings
|
||||
|
||||
|
||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||
'''Score all retrieved entries using the cross-encoder'''
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
|
@ -199,7 +199,7 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
|
|||
return hits
|
||||
|
||||
|
||||
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
||||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||
'''Order results by cross-encoder score followed by bi-encoder score'''
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
|
@ -208,7 +208,7 @@ def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
|||
return hits
|
||||
|
||||
|
||||
def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
def deduplicate_results(entries: List[Entry], hits: List[dict]) -> List[dict]:
|
||||
'''Deduplicate entries by raw entry text before showing to users
|
||||
Compiled entries are split by max tokens supported by ML models.
|
||||
This can result in duplicate hits, entries shown to user.'''
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations # to avoid quoting type hints
|
|||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -29,7 +29,7 @@ class ProcessorType(str, Enum):
|
|||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: list[BaseFilter], top_k):
|
||||
def __init__(self, entries: List[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: List[BaseFilter], top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
# External Packages
|
||||
import openai
|
||||
|
@ -15,7 +16,7 @@ class BaseEncoder(ABC):
|
|||
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
|
||||
def encode(self, entries: List[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class OpenAI(BaseEncoder):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Standard Packages
|
||||
import threading
|
||||
from typing import List
|
||||
from packaging import version
|
||||
|
||||
# External Packages
|
||||
|
@ -19,7 +20,7 @@ config_file: Path = None
|
|||
verbose: int = 0
|
||||
host: str = None
|
||||
port: int = None
|
||||
cli_args: list[str] = None
|
||||
cli_args: List[str] = None
|
||||
query_cache = LRU()
|
||||
search_index_lock = threading.Lock()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue