From 5cfb1bda8967d5e7587136c4b7b45f88931d8890 Mon Sep 17 00:00:00 2001 From: Juuso Alasuutari Date: Mon, 12 Jun 2023 19:41:22 +0300 Subject: [PATCH] llmodel: add model wrapper destructor, fix mem leak in golang bindings (#862) Signed-off-by: Juuso Alasuutari --- gpt4all-backend/llmodel_c.cpp | 21 +++++++++++---------- gpt4all-bindings/golang/binding.cpp | 1 + 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index c5ab5b39..89120507 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -9,6 +9,7 @@ struct LLModelWrapper { LLModel *llModel = nullptr; LLModel::PromptContext promptContext; + ~LLModelWrapper() { delete llModel; } }; @@ -25,33 +26,33 @@ llmodel_model llmodel_model_create(const char *model_path) { llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) { auto wrapper = new LLModelWrapper; - llmodel_error new_error{}; + int error_code = 0; try { wrapper->llModel = LLModel::construct(model_path, build_variant); } catch (const std::exception& e) { - new_error.code = EINVAL; + error_code = EINVAL; last_error_message = e.what(); } if (!wrapper->llModel) { delete std::exchange(wrapper, nullptr); // Get errno and error message if none - if (new_error.code == 0) { - new_error.code = errno; - last_error_message = strerror(errno); + if (error_code == 0) { + error_code = errno; + last_error_message = std::strerror(error_code); } - // Set message pointer - new_error.message = last_error_message.c_str(); // Set error argument - if (error) *error = new_error; + if (error) { + error->message = last_error_message.c_str(); + error->code = error_code; + } } return reinterpret_cast(wrapper); } void llmodel_model_destroy(llmodel_model model) { - LLModelWrapper *wrapper = reinterpret_cast(model); - delete wrapper->llModel; + delete reinterpret_cast(model); } bool llmodel_loadModel(llmodel_model model, const char *model_path) diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp index a0bc4feb..d626dda1 100644 --- a/gpt4all-bindings/golang/binding.cpp +++ b/gpt4all-bindings/golang/binding.cpp @@ -25,6 +25,7 @@ void* load_model(const char *fname, int n_threads) { return nullptr; } if (!llmodel_loadModel(model, fname)) { + llmodel_model_destroy(model); return nullptr; }