From c54c42e3fbd8d3a3acf9ab134d95951c0e7cefaa Mon Sep 17 00:00:00 2001 From: Richard Guo Date: Fri, 2 Jun 2023 10:57:21 -0400 Subject: [PATCH] fixed finding model libs --- gpt4all-backend/llmodel.cpp | 4 ++- gpt4all-backend/llmodel.h | 9 +++++++ gpt4all-backend/llmodel_c.cpp | 10 +++++++ gpt4all-backend/llmodel_c.h | 14 ++++++++++ gpt4all-bindings/python/gpt4all/pyllmodel.py | 28 ++++++-------------- 5 files changed, 44 insertions(+), 21 deletions(-) diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp index e67fdac6..f3b266f5 100644 --- a/gpt4all-backend/llmodel.cpp +++ b/gpt4all-backend/llmodel.cpp @@ -9,6 +9,8 @@ #include #include +std::string LLModel::m_implementations_search_path = "."; + static bool requires_avxonly() { #ifdef __x86_64__ #ifndef _MSC_VER @@ -76,7 +78,7 @@ const std::vector &LLModel::implementationList() { }; const char *custom_impl_lookup_path = getenv("GPT4ALL_IMPLEMENTATIONS_PATH"); - search_in_directory(custom_impl_lookup_path?custom_impl_lookup_path:"."); + search_in_directory(m_implementations_search_path); #if defined(__APPLE__) search_in_directory("../../../"); #endif diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 45a3a3c2..28b53ff8 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -76,9 +76,18 @@ public: static const Implementation *implementation(std::ifstream& f, const std::string& buildVariant); static LLModel *construct(const std::string &modelPath, std::string buildVariant = "default"); + static inline void setImplementationsSearchPath(const std::string& path) { + m_implementations_search_path = path; + } + static inline const std::string& implementationsSearchPath() { + return m_implementations_search_path; + } + protected: const Implementation *m_implementation = nullptr; void recalculateContext(PromptContext &promptCtx, std::function recalculate); + static std::string m_implementations_search_path; + }; #endif // LLMODEL_H diff --git 
a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index a09e20dc..c5ab5b39 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -162,3 +162,13 @@ int32_t llmodel_threadCount(llmodel_model model) LLModelWrapper *wrapper = reinterpret_cast(model); return wrapper->llModel->threadCount(); } + +void llmodel_set_implementation_search_path(const char *path) +{ + LLModel::setImplementationsSearchPath(path); +} + +const char *llmodel_get_implementation_search_path() +{ + return LLModel::implementationsSearchPath().c_str(); +} diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h index ebbd4782..47bac83f 100644 --- a/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/llmodel_c.h @@ -177,6 +177,20 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads); */ int32_t llmodel_threadCount(llmodel_model model); + +/** + * Set llmodel implementation search path. + * Default is "." + * @param path The path to the llmodel implementation shared objects. + */ +void llmodel_set_implementation_search_path(const char *path); + +/** + * Get llmodel implementation search path. + * @return The current search path; the returned pointer remains valid only until the next llmodel_set_implementation_search_path() call. + */ +const char *llmodel_get_implementation_search_path(); + + #ifdef __cplusplus } #endif diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index cdc0856b..799a0153 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -19,7 +19,6 @@ class DualStreamProcessor: self.stream.flush() self.output += cleaned_text - # TODO: provide a config file to make this more robust LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\") @@ -40,31 +39,14 @@ def load_llmodel_library(): llmodel_file = "libllmodel" + '.' 
+ c_lib_ext - model_lib_path = str(pkg_resources.resource_filename("gpt4all", \ - os.path.join(LLMODEL_PATH, f"lib*.{c_lib_ext}"))).replace("\\", "\\\\") - model_lib_dirs = glob.glob(model_lib_path) - - # model_lib_dirs = [] - # print("hello") - # print(model_lib_files) - # for lib in model_lib_files: - # if lib != llmodel_file: - # model_lib_dirs.append(str(pkg_resources.resource_filename('gpt4all', \ - # os.path.join(LLMODEL_PATH, lib))).replace("\\", "\\\\")) - llmodel_dir = str(pkg_resources.resource_filename('gpt4all', \ os.path.join(LLMODEL_PATH, llmodel_file))).replace("\\", "\\\\") - model_libs = [] - for model_dir in model_lib_dirs: - if "libllmodel" not in model_dir: - print("loading") - model_libs.append(ctypes.CDLL(model_dir, mode=ctypes.RTLD_GLOBAL)) llmodel_lib = ctypes.CDLL(llmodel_dir) - return llmodel_lib, model_libs + return llmodel_lib -llmodel, model_libs = load_llmodel_library() +llmodel = load_llmodel_library() class LLModelError(ctypes.Structure): _fields_ = [("message", ctypes.c_char_p), @@ -114,9 +96,15 @@ llmodel.llmodel_prompt.restype = None llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32] llmodel.llmodel_setThreadCount.restype = None +llmodel.llmodel_set_implementation_search_path.argtypes = [ctypes.c_char_p] +llmodel.llmodel_set_implementation_search_path.restype = None + llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p] llmodel.llmodel_threadCount.restype = ctypes.c_int32 +model_lib_path = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\") +llmodel.llmodel_set_implementation_search_path(model_lib_path.encode('utf-8')) + class LLModel: """