diff --git a/gpt4all-bindings/cli/app.py b/gpt4all-bindings/cli/app.py index 6c83f29e..083fa173 100755 --- a/gpt4all-bindings/cli/app.py +++ b/gpt4all-bindings/cli/app.py @@ -59,9 +59,13 @@ def repl( int, typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"), ] = None, + device: Annotated[ + str, + typer.Option("--device", "-d", help="Device to use for chatbot, e.g. gpu, amd, nvidia, intel. Defaults to CPU."), + ] = None, ): """The CLI read-eval-print loop.""" - gpt4all_instance = GPT4All(model) + gpt4all_instance = GPT4All(model, device=device) # if threads are passed, set them if n_threads is not None: