mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-25 23:13:06 +00:00
set n_threads in GPT4All python bindings (#1042)
* set n_threads in GPT4All * changed default n_threads to None
This commit is contained in:
parent
ae3d91476c
commit
aed7b43143
@ -22,7 +22,7 @@ class GPT4All():
|
|||||||
model: Pointer to underlying C model.
|
model: Pointer to underlying C model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True):
|
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None):
|
||||||
"""
|
"""
|
||||||
Constructor
|
Constructor
|
||||||
|
|
||||||
@ -33,12 +33,16 @@ class GPT4All():
|
|||||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||||
descriptive identifier for user. Default is None.
|
descriptive identifier for user. Default is None.
|
||||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||||
|
n_threads: number of CPU threads used by GPT4All. Default is None, than the number of threads are determined automatically.
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model = pyllmodel.LLModel()
|
self.model = pyllmodel.LLModel()
|
||||||
# Retrieve model and download if allowed
|
# Retrieve model and download if allowed
|
||||||
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
|
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
|
||||||
self.model.load_model(model_dest)
|
self.model.load_model(model_dest)
|
||||||
|
# Set n_threads
|
||||||
|
if n_threads != None:
|
||||||
|
self.model.set_thread_count(n_threads)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_models():
|
def list_models():
|
||||||
|
Loading…
Reference in New Issue
Block a user