#!/usr/bin/env python
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import time
import gzip
import os
import sys
import re
import torch
import argparse
import pathlib
from utils.helpers import get_absolute_path
from processor.org_mode.org_to_jsonl import org_to_jsonl


def initialize_model():
    "Initialize model for asymmetric semantic search. That is, where the query is smaller than the results it retrieves"
    bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3')  # The bi-encoder encodes all entries to use for semantic search
    top_k = 100  # Number of entries we want to retrieve with the bi-encoder
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # The cross-encoder re-ranks the results to improve quality

    return bi_encoder, cross_encoder, top_k


def extract_entries(notesfile, verbose=0):
    "Load entries from compressed jsonl"
    entries = []
    with gzip.open(get_absolute_path(notesfile), 'rt', encoding='utf8') as jsonl:
        for line in jsonl:
            note = json.loads(line.strip())

            # Ignore title-only notes, i.e. notes with just a heading and an empty body
            if "Body" not in note or note["Body"].strip() == "":
                continue

            note_string = f'{note["Title"]}\t{note["Tags"] if "Tags" in note else ""}\n{note["Body"] if "Body" in note else ""}'
            entries.append(note_string)

    if verbose > 0:
        print(f"Loaded {len(entries)} entries from {notesfile}")

    return entries
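
# Illustrative sketch of one JSON line expected in the compressed JSONL file above.
# The keys match those read in extract_entries; the values and tag format are made up:
#   {"Title": "Grocery list", "Tags": "errand", "Body": "Buy milk and eggs"}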


def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
    "Compute (and save) embeddings or load pre-computed embeddings"
    # Load pre-computed embeddings from file if it exists and regeneration is not requested
    if embeddings_file.exists() and not regenerate:
        corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
        if verbose > 0:
            print(f"Loaded embeddings from {embeddings_file}")

    # Else compute the corpus embeddings from scratch, which can take a while
    else:
        corpus_embeddings = bi_encoder.encode(entries, convert_to_tensor=True, show_progress_bar=True)
        torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
        if verbose > 0:
            print(f"Computed embeddings and saved them to {embeddings_file}")

    return corpus_embeddings


def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
    "Search all notes for entries that answer the query"
    # Separate the natural language query from the explicit required (+word) and blocked (-word) word filters
    query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
    required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
    blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])

    # Encode the query using the bi-encoder
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # Find relevant entries for the query
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    # Filter results using the explicit word filters
    hits = explicit_filter(hits, entries, required_words, blocked_words)
    if hits is None or len(hits) == 0:
        return hits

    # Score all retrieved entries using the cross-encoder
    cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Store cross-encoder scores in the results dictionary for ranking
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Order results by cross-encoder score followed by bi-encoder score
    hits.sort(key=lambda x: x['score'], reverse=True)        # sort by bi-encoder score
    hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross-encoder score

    return hits
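
# Illustrative example of the filter syntax parsed above (the query text is made up):
# a raw query like "+india -tax monsoon impact on farming" searches for
# "monsoon impact on farming", keeps only hits containing "india" and drops hits containing "tax".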


def explicit_filter(hits, entries, required_words, blocked_words):
    "Filter hits by the explicitly required and blocked words in their entries"
    if len(required_words) == 0 and len(blocked_words) == 0:
        return hits

    # Map each hit to the set of lowercased words in its entry
    hits_by_word_set = [(set(word.lower()
                             for word
                             in re.split(
                                 r',|\.| |\]|\[|\(|\)|\{|\}',  # split on spaces and common punctuation
                                 entries[hit['corpus_id']])
                             if word != ""),
                         hit)
                        for hit in hits]

    if len(required_words) > 0:
        return [hit for (words_in_entry, hit) in hits_by_word_set
                if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
    if len(blocked_words) > 0:
        return [hit for (words_in_entry, hit) in hits_by_word_set
                if not blocked_words.intersection(words_in_entry)]

    return hits


def render_results(hits, entries, count=5, display_biencoder_results=False):
    "Render the results returned by search for the query"
    if display_biencoder_results:
        # Output of top hits from bi-encoder
        print("\n-------------------------\n")
        print(f"Top-{count} Bi-Encoder Retrieval hits")
        hits = sorted(hits, key=lambda x: x['score'], reverse=True)
        for hit in hits[0:count]:
            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]}")

    # Output of top hits from re-ranker
    print("\n-------------------------\n")
    print(f"Top-{count} Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:count]:
        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]}")


def collate_results(hits, entries, count=5):
    return [
        {
            "Entry": entries[hit['corpus_id']],
            "Score": f"{hit['cross-score']:.3f}"
        }
        for hit
        in hits[0:count]]
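
# collate_results packages the top hits for programmatic consumption rather than printing them.
# A sketch of its output shape, with made-up entry text and score:
#   [{"Entry": "Title\ttag\nBody of the note", "Score": "7.123"}, ...]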


def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False):
    # Initialize Model
    bi_encoder, cross_encoder, top_k = initialize_model()

    # Map notes in org-mode files to (compressed) JSONL formatted file
    if not compressed_jsonl.exists() or regenerate:
        org_to_jsonl(input_files, input_filter, compressed_jsonl, verbose)

    # Extract Entries
    entries = extract_entries(compressed_jsonl, verbose)

    # Compute or Load Embeddings
    corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose)

    return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k


if __name__ == '__main__':
    # Setup Argument Parser
    parser = argparse.ArgumentParser(description="Semantic search on org-mode notes using a bi-encoder and cross-encoder")
    parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
    parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
    parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path(".notes.jsonl.gz"), help="Compressed JSONL formatted notes file to compute embeddings from")
    parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path(".notes_embeddings.pt"), help="File to save/load model embeddings to/from")
    parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from org-mode files. Default: false")
    parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
    parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: false")
    parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
    args = parser.parse_args()

    entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)

    # Run user queries on entries in interactive mode
    while args.interactive:
        # Get query from user
        user_query = input("Enter your query: ")
        if user_query == "exit":
            exit(0)

        # Query notes
        hits = query_notes(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)

        # Render results
        render_results(hits, entries, count=args.results_count)
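
# Example invocation (the script name and note paths are illustrative, not taken from this repo):
#   python asymmetric_search.py --input-files ~/notes/*.org --interactive --verbose
# This converts the org files to .notes.jsonl.gz, computes .notes_embeddings.pt,
# then prompts for queries such as "+india -tax monsoon impact on farming" (type "exit" to quit).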