mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Fix passing of device to setup method in /reload, /regenerate API
- Use local variable to pass device to asymmetric.setup method via /reload, /regenerate API - Set default argument to torch.device('cpu') instead of 'cpu' to be more formal
This commit is contained in:
parent
eda4b65ddb
commit
7677465f23
2 changed files with 10 additions and 5 deletions
11
src/main.py
11
src/main.py
|
@ -25,7 +25,6 @@ processor_config = ProcessorConfigModel()
|
||||||
config_file = ""
|
config_file = ""
|
||||||
verbose = 0
|
verbose = 0
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
||||||
|
|
||||||
app.mount("/views", StaticFiles(directory="views"), name="views")
|
app.mount("/views", StaticFiles(directory="views"), name="views")
|
||||||
templates = Jinja2Templates(directory="views/")
|
templates = Jinja2Templates(directory="views/")
|
||||||
|
@ -53,6 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
print(f'No query param (q) passed in API call to initiate search')
|
print(f'No query param (q) passed in API call to initiate search')
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
user_query = q
|
user_query = q
|
||||||
results_count = n
|
results_count = n
|
||||||
|
|
||||||
|
@ -95,6 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
@app.get('/reload')
|
@app.get('/reload')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
global model
|
global model
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = initialize_search(config, regenerate=False, t=t, device=device)
|
model = initialize_search(config, regenerate=False, t=t, device=device)
|
||||||
return {'status': 'ok', 'message': 'reload completed'}
|
return {'status': 'ok', 'message': 'reload completed'}
|
||||||
|
|
||||||
|
@ -102,6 +103,7 @@ def regenerate(t: Optional[SearchType] = None):
|
||||||
@app.get('/regenerate')
|
@app.get('/regenerate')
|
||||||
def regenerate(t: Optional[SearchType] = None):
|
def regenerate(t: Optional[SearchType] = None):
|
||||||
global model
|
global model
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = initialize_search(config, regenerate=True, t=t, device=device)
|
model = initialize_search(config, regenerate=True, t=t, device=device)
|
||||||
return {'status': 'ok', 'message': 'regeneration completed'}
|
return {'status': 'ok', 'message': 'regeneration completed'}
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ def chat(q: str):
|
||||||
return {'status': 'ok', 'response': gpt_response}
|
return {'status': 'ok', 'response': gpt_response}
|
||||||
|
|
||||||
|
|
||||||
def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None):
|
def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu")):
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (t == SearchType.Notes or t == None) and config.content_type.org:
|
if (t == SearchType.Notes or t == None) and config.content_type.org:
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
|
@ -241,8 +243,11 @@ if __name__ == '__main__':
|
||||||
# Store the raw config data.
|
# Store the raw config data.
|
||||||
config = args.config
|
config = args.config
|
||||||
|
|
||||||
|
# Set device to GPU if available
|
||||||
|
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
|
||||||
# Initialize the search model from Config
|
# Initialize the search model from Config
|
||||||
model = initialize_search(args.config, args.regenerate)
|
model = initialize_search(args.config, args.regenerate, device=device)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
processor_config = initialize_processor(args.config)
|
processor_config = initialize_processor(args.config)
|
||||||
|
|
|
@ -93,7 +93,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, device='cpu'):
|
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
|
||||||
"Search all notes for entries that answer the query"
|
"Search all notes for entries that answer the query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from explicit required, blocked words filters
|
||||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||||
|
@ -180,7 +180,7 @@ def collate_results(hits, entries, count=5):
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device='cpu', verbose: bool=False) -> TextSearchModel:
|
def setup(config: TextContentConfig, search_config: AsymmetricSearchConfig, regenerate: bool, device=torch.device('cpu'), verbose: bool=False) -> TextSearchModel:
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue