Make normalizing embeddings configurable

Debanjum Singh Solanky 2023-07-16 02:16:33 -07:00
parent 1482fd4d4d
commit ad41ef3991


@@ -58,7 +58,11 @@ def extract_entries(jsonl_file) -> List[Entry]:
 def compute_embeddings(
-    entries_with_ids: List[Tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file: Path, regenerate=False
+    entries_with_ids: List[Tuple[int, Entry]],
+    bi_encoder: BaseEncoder,
+    embeddings_file: Path,
+    regenerate=False,
+    normalize=True,
 ):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
@@ -87,8 +91,11 @@ def compute_embeddings(
             existing_embeddings = torch.tensor([], device=state.device)
         corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
 
+        if normalize:
+            # Normalize embeddings for faster lookup via dot product when querying
+            corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
         # Save regenerated or updated embeddings to file
-        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, embeddings_file)
         logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
@@ -169,6 +176,7 @@ def setup(
     bi_encoder: BaseEncoder,
     regenerate: bool,
     filters: List[BaseFilter] = [],
+    normalize: bool = True,
 ) -> TextContent:
     # Map notes in text files to (compressed) JSONL formatted file
     config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
@@ -186,7 +194,7 @@
     # Compute or Load Embeddings
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(
-        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate
+        entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate, normalize=normalize
     )
     for filter in filters:
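
The new normalize flag (default True) keeps the existing behaviour of L2-normalizing the corpus embeddings before they are saved, while letting callers of setup opt out. Below is a minimal sketch of why normalized embeddings allow the faster dot-product lookup mentioned in the added comment; it uses plain torch tensors in place of the bi-encoder output, so the shapes and names are illustrative only, not taken from this codebase.

import torch

torch.manual_seed(0)
corpus_embeddings = torch.randn(5, 384)  # stand-in for bi_encoder output: 5 entries, 384-dim
query_embedding = torch.randn(1, 384)    # stand-in for an encoded query

# L2-normalize each row, which is what util.normalize_embeddings does
corpus_norm = torch.nn.functional.normalize(corpus_embeddings, p=2, dim=1)
query_norm = torch.nn.functional.normalize(query_embedding, p=2, dim=1)

# For unit-length vectors, dot product equals cosine similarity,
# so ranking at query time is a single matrix multiply.
dot_scores = (query_norm @ corpus_norm.T).squeeze(0)
cos_scores = torch.nn.functional.cosine_similarity(query_embedding, corpus_embeddings, dim=1)
assert torch.allclose(dot_scores, cos_scores, atol=1e-6)

top_entries = torch.topk(dot_scores, k=3).indices  # indices of the best-matching entries

With normalize=False, the saved embeddings keep their original magnitudes, which matters when a downstream scorer is not purely dot-product based.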