diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py index 892d72e7..4e578064 100644 --- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py @@ -5,6 +5,7 @@ import os import platform import re import subprocess +from pathlib import Path import sys import threading from enum import Enum @@ -54,16 +55,37 @@ MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" def load_llmodel_library(): + """ + Loads the llmodel shared library based on the current operating system. + + This function attempts to load the shared library using the appropriate file + extension for the operating system. It first tries to load the library with the + 'lib' prefix (common for macOS, Linux, and MinGW on Windows). If the file is not + found and the operating system is Windows, it attempts to load the library without + the 'lib' prefix (common for MSVC on Windows). + + Returns: + ctypes.CDLL: The loaded shared library. + + Raises: + OSError: If the shared library cannot be found. + """ + # Determine the appropriate file extension for the shared library based on the platform ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()] + # Define library names with and without the 'lib' prefix + library_name_with_lib_prefix = f"libllmodel.{ext}" + library_name_without_lib_prefix = "llmodel.dll" + base_path = MODEL_LIB_PATH + try: - # macOS, Linux, MinGW - lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}")) - except FileNotFoundError: - if ext != 'dll': + # Attempt to load the shared library with the 'lib' prefix (common for macOS, Linux, and MinGW) + lib = ctypes.CDLL(str(base_path / library_name_with_lib_prefix)) + except OSError: # OSError is more general and includes FileNotFoundError + if ext != "dll": raise - # MSVC - lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll")) + # For Windows (ext == 'dll'), attempt to load the shared library without the 'lib' prefix (common for MSVC) + lib = ctypes.CDLL(str(base_path / library_name_without_lib_prefix)) return lib