mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Expose CLI flag to disable using GPU for offline chat model
- Offline chat models were outputting gibberish when loaded onto some GPUs. GPU support with Vulkan in GPT4All seems a bit buggy. This change mitigates the upstream issue by allowing the user to manually disable using the GPU for offline chat. Closes #516
This commit is contained in:
parent
5bb14a05a0
commit
9677eae791
4 changed files with 14 additions and 2 deletions
|
@ -94,6 +94,7 @@ def set_state(args):
|
|||
state.port = args.port
|
||||
state.demo = args.demo
|
||||
state.khoj_version = version("khoj-assistant")
|
||||
state.chat_on_gpu = args.chat_on_gpu
|
||||
|
||||
|
||||
def start_server(app, host=None, port=None, socket=None):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import logging
|
||||
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -13,8 +15,9 @@ def download_model(model_name: str):
|
|||
|
||||
# Use GPU for Chat Model, if available
|
||||
try:
|
||||
model = GPT4All(model_name=model_name, device="gpu")
|
||||
logger.debug(f"Loaded {model_name} chat model to GPU.")
|
||||
device = "gpu" if state.chat_on_gpu else "cpu"
|
||||
model = GPT4All(model_name=model_name, device=device)
|
||||
logger.debug(f"Loaded {model_name} chat model to {device.upper()}")
|
||||
except ValueError:
|
||||
model = GPT4All(model_name=model_name)
|
||||
logger.debug(f"Loaded {model_name} chat model to CPU.")
|
||||
|
|
|
@ -34,10 +34,16 @@ def cli(args=None):
|
|||
help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
|
||||
)
|
||||
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
||||
parser.add_argument(
|
||||
"--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
|
||||
)
|
||||
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")
|
||||
|
||||
args = parser.parse_args(args)
|
||||
|
||||
# Set default values for arguments
|
||||
args.chat_on_gpu = not args.disable_chat_on_gpu
|
||||
|
||||
args.version_no = version("khoj-assistant")
|
||||
if args.version:
|
||||
# Show version of khoj installed and exit
|
||||
|
|
|
@ -31,6 +31,8 @@ telemetry: List[Dict[str, str]] = []
|
|||
previous_query: str = None
|
||||
demo: bool = False
|
||||
khoj_version: str = None
|
||||
chat_on_gpu: bool = True
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
|
|
Loading…
Reference in a new issue