mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-17 08:52:06 +00:00
python: depend on offical NVIDIA CUDA packages (#2355)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
c779d8a32d
commit
09dd3dc318
@ -28,6 +28,27 @@ if TYPE_CHECKING:
|
||||
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
|
||||
|
||||
|
||||
# Find CUDA libraries from the official packages
|
||||
cuda_found = False
|
||||
if platform.system() in ('Linux', 'Windows'):
|
||||
try:
|
||||
from nvidia import cuda_runtime, cublas
|
||||
except ImportError:
|
||||
pass # CUDA is optional
|
||||
else:
|
||||
if platform.system() == 'Linux':
|
||||
cudalib = 'lib/libcudart.so.12'
|
||||
cublaslib = 'lib/libcublas.so.12'
|
||||
else: # Windows
|
||||
cudalib = r'bin\cudart64_12.dll'
|
||||
cublaslib = r'bin\cublas64_12.dll'
|
||||
|
||||
# preload the CUDA libs so the backend can find them
|
||||
ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
|
||||
ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)
|
||||
cuda_found = True
|
||||
|
||||
|
||||
# TODO: provide a config file to make this more robust
|
||||
MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
|
||||
|
||||
@ -218,7 +239,16 @@ class LLModel:
|
||||
model = llmodel.llmodel_model_create2(self.model_path, backend.encode(), ctypes.byref(err))
|
||||
if model is None:
|
||||
s = err.value
|
||||
raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
|
||||
errmsg = 'null' if s is None else s.decode()
|
||||
|
||||
if (
|
||||
backend == 'cuda'
|
||||
and not cuda_found
|
||||
and errmsg.startswith('Could not find any implementations for backend')
|
||||
):
|
||||
print('WARNING: CUDA runtime libraries not found. Try `pip install "gpt4all[cuda]"`\n', file=sys.stderr)
|
||||
|
||||
raise RuntimeError(f"Unable to instantiate model: {errmsg}")
|
||||
self.model: ctypes.c_void_p | None = model
|
||||
|
||||
def __del__(self, llmodel=llmodel):
|
||||
|
@ -93,7 +93,15 @@ setup(
|
||||
'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
|
||||
],
|
||||
extras_require={
|
||||
'cuda': [
|
||||
'nvidia-cuda-runtime-cu12',
|
||||
'nvidia-cublas-cu12',
|
||||
],
|
||||
'all': [
|
||||
'gpt4all[cuda]; platform_system == "Windows" or platform_system == "Linux"',
|
||||
],
|
||||
'dev': [
|
||||
'gpt4all[all]',
|
||||
'pytest',
|
||||
'twine',
|
||||
'wheel',
|
||||
|
Loading…
Reference in New Issue
Block a user