Expose CLI flag to disable using GPU for offline chat model

- Offline chat models output gibberish when loaded onto some GPUs.
  GPU support via Vulkan in GPT4All seems a bit buggy

- This change mitigates the upstream issue by allowing users to
  manually disable using the GPU for offline chat

Closes #516
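
With the flag exposed, affected users can force CPU inference at startup, e.g. khoj --disable-chat-on-gpu (assuming the standard khoj console entrypoint installed by the khoj-assistant package).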
Debanjum Singh Solanky 2023-10-25 17:51:46 -07:00
parent 5bb14a05a0
commit 9677eae791
4 changed files with 14 additions and 2 deletions


@@ -94,6 +94,7 @@ def set_state(args):
     state.port = args.port
     state.demo = args.demo
     state.khoj_version = version("khoj-assistant")
+    state.chat_on_gpu = args.chat_on_gpu


 def start_server(app, host=None, port=None, socket=None):


@@ -1,5 +1,7 @@
 import logging

+from khoj.utils import state
+
 logger = logging.getLogger(__name__)

@@ -13,8 +15,9 @@ def download_model(model_name: str):
     # Use GPU for Chat Model, if available
     try:
-        model = GPT4All(model_name=model_name, device="gpu")
-        logger.debug(f"Loaded {model_name} chat model to GPU.")
+        device = "gpu" if state.chat_on_gpu else "cpu"
+        model = GPT4All(model_name=model_name, device=device)
+        logger.debug(f"Loaded {model_name} chat model to {device.upper()}")
     except ValueError:
         model = GPT4All(model_name=model_name)
         logger.debug(f"Loaded {model_name} chat model to CPU.")


@@ -34,10 +34,16 @@ def cli(args=None):
         help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
     )
     parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
+    parser.add_argument(
+        "--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
+    )
     parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")

     args = parser.parse_args(args)

+    # Set default values for arguments
+    args.chat_on_gpu = not args.disable_chat_on_gpu
+
     args.version_no = version("khoj-assistant")
     if args.version:
         # Show version of khoj installed and exit
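
A note on the flag design: the diff adds a negative flag and inverts it after parsing. argparse can express the same semantics directly with store_false and an explicit dest, avoiding the post-parse inversion; a small self-contained sketch of that alternative (not what the commit does):

import argparse

parser = argparse.ArgumentParser()
# Stores False into args.chat_on_gpu; the attribute defaults to True
parser.add_argument(
    "--disable-chat-on-gpu",
    dest="chat_on_gpu",
    action="store_false",
    help="Disable using GPU for the offline chat model",
)

assert parser.parse_args([]).chat_on_gpu is True
assert parser.parse_args(["--disable-chat-on-gpu"]).chat_on_gpu is False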


@@ -31,6 +31,8 @@ telemetry: List[Dict[str, str]] = []
 previous_query: str = None
 demo: bool = False
 khoj_version: str = None
+chat_on_gpu: bool = True
+

 if torch.cuda.is_available():
     # Use CUDA GPU
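
The state module acts as process-wide mutable configuration: set_state() writes parsed CLI values onto it once at startup, and any later importer reads the same attributes. A minimal runnable sketch of the pattern, using a synthetic module in place of khoj.utils.state:

import sys
import types

# Stand-in for khoj.utils.state: module attributes shared process-wide
state = types.ModuleType("state")
state.chat_on_gpu = True
sys.modules["state"] = state

# At startup, set_state() copies parsed CLI values onto the module
state.chat_on_gpu = False  # as if --disable-chat-on-gpu were passed

# Any later import resolves to the same module object and sees the update
import state as chat_state

device = "gpu" if chat_state.chat_on_gpu else "cpu"
assert device == "cpu"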