diff --git a/src/khoj/main.py b/src/khoj/main.py
index 7b1bfd7e..4c759c2a 100644
--- a/src/khoj/main.py
+++ b/src/khoj/main.py
@@ -94,6 +94,7 @@ def set_state(args):
     state.port = args.port
     state.demo = args.demo
     state.khoj_version = version("khoj-assistant")
+    state.chat_on_gpu = args.chat_on_gpu
 
 
 def start_server(app, host=None, port=None, socket=None):
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index 2bb1fbbc..45a1158e 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -1,5 +1,7 @@
 import logging
 
+from khoj.utils import state
+
 logger = logging.getLogger(__name__)
 
 
@@ -13,8 +15,9 @@ def download_model(model_name: str):
 
     # Use GPU for Chat Model, if available
     try:
-        model = GPT4All(model_name=model_name, device="gpu")
-        logger.debug(f"Loaded {model_name} chat model to GPU.")
+        device = "gpu" if state.chat_on_gpu else "cpu"
+        model = GPT4All(model_name=model_name, device=device)
+        logger.debug(f"Loaded {model_name} chat model to {device.upper()}")
     except ValueError:
         model = GPT4All(model_name=model_name)
         logger.debug(f"Loaded {model_name} chat model to CPU.")
diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py
index 7c72b101..9f129b17 100644
--- a/src/khoj/utils/cli.py
+++ b/src/khoj/utils/cli.py
@@ -34,10 +34,16 @@ def cli(args=None):
         help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
     )
     parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
+    parser.add_argument(
+        "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
+    )
     parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
 
     args = parser.parse_args(args)
 
+    # Set default values for arguments
+    args.chat_on_gpu = not args.disable_chat_on_gpu
+
     args.version_no = version("khoj-assistant")
     if args.version:
         # Show version of khoj installed and exit
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 5ac8a838..e9b2ca6c 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -31,6 +31,8 @@ telemetry: List[Dict[str, str]] = []
 previous_query: str = None
 demo: bool = False
 khoj_version: str = None
+chat_on_gpu: bool = True
+
 
 if torch.cuda.is_available():
     # Use CUDA GPU
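
Usage note (not part of the patch): with this change the offline chat model still loads on GPU by default, and starting the server with the new flag, e.g. "khoj --disable-chat-on-gpu", keeps it on CPU (assuming "khoj" is the installed console script of the khoj-assistant package). Below is a minimal sketch of the resulting device selection, mirroring download_model() in the diff above; the helper name load_chat_model is hypothetical and the gpt4all package must be installed:

    from gpt4all import GPT4All

    from khoj.utils import state


    def load_chat_model(model_name: str) -> GPT4All:
        # Pick the device from the global flag that cli.py/main.py set from --disable-chat-on-gpu
        device = "gpu" if state.chat_on_gpu else "cpu"
        try:
            return GPT4All(model_name=model_name, device=device)
        except ValueError:
            # Same fallback as the patched download_model(): GPT4All raises
            # ValueError when the requested GPU backend is unavailable, so load on CPU
            return GPT4All(model_name=model_name)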