From 7677465f23dd601b3e1fc52cc66bf4388355f76a Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 30 Jun 2022 01:32:56 +0400
Subject: [PATCH] Fix passing of device to setup method in /reload, /regenerate API

- Use local variable to pass device to asymmetric.setup method via the
  /reload, /regenerate API
- Set default device argument to torch.device('cpu') instead of the
  string 'cpu' to be more explicit

---
 src/main.py                   | 11 ++++++++---
 src/search_type/asymmetric.py |  4 ++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/main.py b/src/main.py
index 557dd29d..4666e30f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -25,7 +25,6 @@ processor_config = ProcessorConfigModel()
 config_file = ""
 verbose = 0
 app = FastAPI()
-device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 app.mount("/views", StaticFiles(directory="views"), name="views")
 templates = Jinja2Templates(directory="views/")
@@ -53,6 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
         print(f'No query param (q) passed in API call to initiate search')
         return {}
 
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     user_query = q
     results_count = n
 
@@ -95,6 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
 @app.get('/reload')
 def regenerate(t: Optional[SearchType] = None):
     global model
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     model = initialize_search(config, regenerate=False, t=t, device=device)
     return {'status': 'ok', 'message': 'reload completed'}
 
@@ -102,6 +103,7 @@ def regenerate(t: Optional[SearchType] = None):
 @app.get('/regenerate')
 def regenerate(t: Optional[SearchType] = None):
     global model
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     model = initialize_search(config, regenerate=True, t=t, device=device)
     return {'status': 'ok', 'message': 'regeneration completed'}
 
@@ -147,7 +149,7 @@ def chat(q: str):
     return {'status': 'ok', 'response': gpt_response}
 
 
-def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None):
+def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu")):
     # Initialize Org Notes Search
     if (t == SearchType.Notes or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
@@ -241,8 +243,11 @@ if __name__ == '__main__':
     # Store the raw config data.
     config = args.config
 
+    # Set device to GPU if available
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
     # Initialize the search model from Config
-    model = initialize_search(args.config, args.regenerate)
+    model = initialize_search(args.config, args.regenerate, device=device)
 
     # Initialize Processor from Config
     processor_config = initialize_processor(args.config)
diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py
index aa64e128..524eb9a1 100644
--- a/src/search_type/asymmetric.py
+++ b/src/search_type/asymmetric.py
@@ -93,7 +93,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
     return corpus_embeddings
 
 
-def query(raw_query: str, model: TextSearchModel, device='cpu'):
+def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
     "Search all notes for entries that answer the query"
     # Separate natural query from explicit required, blocked words filters
     query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
@@ -180,7 +180,7 @@ def collate_results(hits, entries, count=5):
         in hits[0:count]]
 
 
-def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
+def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device=torch.device('cpu'), verbose: bool=False) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
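
For context, a minimal sketch of the device-selection and device-passing pattern the
patch standardizes on. The setup function below is a simplified stand-in for
initialize_search / asymmetric.setup, not the project's actual signature:

    import torch

    # Pick the first CUDA GPU when one is visible, otherwise fall back to CPU.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    def setup(regenerate: bool, device=torch.device("cpu")):
        # Defaulting to torch.device("cpu") rather than the string "cpu" keeps the
        # parameter a torch.device object in every code path.
        return torch.zeros(3, device=device)

    # Caller computes the device locally and threads it through as a keyword argument.
    print(setup(regenerate=False, device=device).device)

Most torch APIs accept either a torch.device or a plain string, but using the
torch.device default keeps the parameter's type consistent wherever it is compared
or logged.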