From db7483329cfb7fbf0297c1a613c003020c664375 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 9 Jan 2023 18:37:50 -0300 Subject: [PATCH] Only import type hint packages for type checking. Avoids circular imports Use annotations from the __future__ package to avoid having to quote type hints. This import will not be required after Python 3.11 --- src/utils/config.py | 12 ++++++++---- src/utils/helpers.py | 23 ++++++++++++----------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/utils/config.py b/src/utils/config.py index ee999ba6..118f766f 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -1,15 +1,19 @@ # System Packages +from __future__ import annotations # to avoid quoting type hints from enum import Enum from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING # External Packages import torch # Internal Packages -from src.utils.rawconfig import ConversationProcessorConfig, Entry -from src.search_filter.base_filter import BaseFilter -from src.utils.models import BaseEncoder +if TYPE_CHECKING: + from sentence_transformers import CrossEncoder + from src.search_filter.base_filter import BaseFilter + from src.utils.models import BaseEncoder + from src.utils.rawconfig import ConversationProcessorConfig, Entry class SearchType(str, Enum): @@ -25,7 +29,7 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k): + def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder: CrossEncoder, filters: list[BaseFilter], top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 4dd6a78e..1bac6e81 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,17 +1,18 @@ # Standard Packages -from pathlib import Path -from importlib import import_module -import sys -from os.path import join -from collections import OrderedDict -from typing import Optional, Union +from __future__ import annotations # to avoid quoting type hints import logging +import sys +from collections import OrderedDict +from importlib import import_module +from os.path import join +from pathlib import Path +from typing import Optional, Union, TYPE_CHECKING -# External Packages -from sentence_transformers import CrossEncoder - -# Internal Packages -from src.utils.models import BaseEncoder +if TYPE_CHECKING: + # External Packages + from sentence_transformers import CrossEncoder + # Internal Packages + from src.utils.models import BaseEncoder def is_none_or_empty(item):