Rename 'embed' key to more generic 'compiled' for jsonl extracted results

- While it's true those strings are going to be used to generate
  embeddings, the more generic term allows them to be used elsewhere as
  well

- Their main property is that they are processed and compiled for
  use by semantic search

- Unlike the 'raw' string, which contains the external representation
  of the data as-is
This commit is contained in:
Debanjum Singh Solanky 2022-07-20 20:35:50 +04:00
parent c1369233db
commit 70e70d4b15
3 changed files with 12 additions and 12 deletions

View file

@ -63,7 +63,7 @@ def extract_entries(notesfile, verbose=0):
note_string = f'{note["Title"]}' \
f'\t{note["Tags"] if "Tags" in note else ""}' \
f'\n{note["Body"] if "Body" in note else ""}'
entries.append({'embed': note_string, 'raw': note["Raw"]})
entries.append({'compiled': note_string, 'raw': note["Raw"]})
# Close File
jsonl_file.close()
@ -83,7 +83,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
print(f"Loaded embeddings from {embeddings_file}")
else: # Else compute the corpus_embeddings from scratch, which can take a while
corpus_embeddings = bi_encoder.encode([entry['embed'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
corpus_embeddings.to(device)
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
@ -116,7 +116,7 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu'), fi
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits]
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
@ -138,14 +138,14 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}")
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
# Output of top hits from re-ranker
print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}")
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
def collate_results(hits, entries, count=5):

View file

@ -38,7 +38,7 @@ def initialize_model(search_config: SymmetricSearchConfig):
def extract_entries(notesfile, verbose=0):
"Load entries from compressed jsonl"
return [{'raw': f'{entry["Title"]}', 'embed': f'{entry["Title"]}'}
return [{'raw': f'{entry["Title"]}', 'compiled': f'{entry["Title"]}'}
for entry
in load_jsonl(notesfile, verbose=verbose)]
@ -80,7 +80,7 @@ def query(raw_query, model: TextSearchModel, filters=[]):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k)[0]
# Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]['embed']] for hit in hits]
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
@ -102,14 +102,14 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['embed']}")
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
# Output of top hits from re-ranker
print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['embed']}")
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
def collate_results(hits, entries, count=5):

View file

@ -13,9 +13,9 @@ from src.search_filter import date_filter
def test_date_filter():
embeddings = torch.randn(3, 10)
entries = [
{'embed': '', 'raw': 'Entry with no date'},
{'embed': '', 'raw': 'April Fools entry: 1984-04-01'},
{'embed': '', 'raw': 'Entry with date:1984-04-02'}]
{'compiled': '', 'raw': 'Entry with no date'},
{'compiled': '', 'raw': 'April Fools entry: 1984-04-01'},
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)