From 33557b1f39e64648ecc2ea1cdc062e39d915e56a Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Fri, 7 Jul 2023 12:34:12 -0400
Subject: [PATCH] Move the implementation out of llmodel class.

---
 gpt4all-backend/llmodel.cpp | 22 +++++++-------
 gpt4all-backend/llmodel.h   | 60 ++++++++++++++++++++++++++++++-----------
 2 files changed, 43 insertions(+), 39 deletions(-)

diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp
index fdf3597a..5dd33535 100644
--- a/gpt4all-backend/llmodel.cpp
+++ b/gpt4all-backend/llmodel.cpp
@@ -41,7 +41,7 @@ static bool requires_avxonly() {
 #endif
 }
 
-LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) {
+LLImplementation::LLImplementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) {
     auto get_model_type = dlhandle->get<const char *()>("get_model_type");
     assert(get_model_type);
     modelType = get_model_type();
@@ -54,7 +54,7 @@ LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) : dlhandle(new Dlh
     assert(construct_);
 }
 
-LLModel::Implementation::Implementation(Implementation &&o)
+LLImplementation::LLImplementation(LLImplementation &&o)
     : construct_(o.construct_)
     , modelType(o.modelType)
     , buildVariant(o.buildVariant)
@@ -63,19 +63,19 @@ LLModel::Implementation::Implementation(Implementation &&o)
     o.dlhandle = nullptr;
 }
 
-LLModel::Implementation::~Implementation() {
+LLImplementation::~LLImplementation() {
     if (dlhandle) delete dlhandle;
 }
 
-bool LLModel::Implementation::isImplementation(const Dlhandle &dl) {
+bool LLImplementation::isImplementation(const Dlhandle &dl) {
     return dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation");
 }
 
-const std::vector<LLModel::Implementation> &LLModel::implementationList() {
+const std::vector<LLImplementation> &LLModel::implementationList() {
     // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
     // individual models without the cleanup of the static list interfering
-    static auto* libs = new std::vector<Implementation>([] () {
-        std::vector<Implementation> fres;
+    static auto* libs = new std::vector<LLImplementation>([] () {
+        std::vector<LLImplementation> fres;
 
         auto search_in_directory = [&](const std::string& paths) {
             std::stringstream ss(paths);
@@ -90,10 +90,10 @@ const std::vector<LLModel::Implementation> &LLModel::implementationList() {
                     // Add to list if model implementation
                     try {
                         Dlhandle dl(p.string());
-                        if (!Implementation::isImplementation(dl)) {
+                        if (!LLImplementation::isImplementation(dl)) {
                             continue;
                         }
-                        fres.emplace_back(Implementation(std::move(dl)));
+                        fres.emplace_back(LLImplementation(std::move(dl)));
                     } catch (...) {}
                 }
             }
@@ -107,7 +107,7 @@ const std::vector<LLModel::Implementation> &LLModel::implementationList() {
     return *libs;
 }
 
-const LLModel::Implementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) {
+const LLImplementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) {
     for (const auto& i : implementationList()) {
         f.seekg(0);
         if (!i.magicMatch(f)) continue;
@@ -126,7 +126,7 @@ LLModel *LLModel::construct(const std::string &modelPath, std::string buildVaria
     std::ifstream f(modelPath, std::ios::binary);
     if (!f) return nullptr;
     // Get correct implementation
-    const LLModel::Implementation* impl = nullptr;
+    const LLImplementation* impl = nullptr;
 
     #if defined(__APPLE__) && defined(__arm64__) // FIXME: See if metal works for intel macs
     if (buildVariant == "auto") {
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index ce7a6f57..920bc350 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -12,34 +12,11 @@
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
 class Dlhandle;
-
+class LLImplementation;
 class LLModel {
 public:
     using Token = int32_t;
 
-    class Implementation {
-        LLModel *(*construct_)();
-
-    public:
-        Implementation(Dlhandle&&);
-        Implementation(const Implementation&) = delete;
-        Implementation(Implementation&&);
-        ~Implementation();
-
-        static bool isImplementation(const Dlhandle&);
-
-        std::string_view modelType, buildVariant;
-        bool (*magicMatch)(std::ifstream& f);
-        Dlhandle *dlhandle;
-
-        // The only way an implementation should be constructed
-        LLModel *construct() const {
-            auto fres = construct_();
-            fres->m_implementation = this;
-            return fres;
-        }
-    };
-
     struct PromptContext {
         std::vector<float> logits;          // logits of current context
         std::vector<int32_t> tokens;        // current tokens in the context window
@@ -74,12 +51,12 @@ public:
     virtual void setThreadCount(int32_t /*n_threads*/) {}
     virtual int32_t threadCount() const { return 1; }
 
-    const Implementation& implementation() const {
+    const LLImplementation& implementation() const {
         return *m_implementation;
     }
 
-    static const std::vector<Implementation>& implementationList();
-    static const Implementation *implementation(std::ifstream& f, const std::string& buildVariant);
+    static const std::vector<LLImplementation>& implementationList();
+    static const LLImplementation *implementation(std::ifstream& f, const std::string& buildVariant);
     static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
 
     static void setImplementationsSearchPath(const std::string& path);
@@ -99,6 +76,33 @@ protected:
     // shared by all base classes so it isn't virtual
     void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
 
-    const Implementation *m_implementation = nullptr;
+    const LLImplementation *m_implementation = nullptr;
+
+private:
+    friend class LLImplementation;
 };
+
+class LLImplementation {
+    LLModel *(*construct_)();
+
+public:
+    LLImplementation(Dlhandle&&);
+    LLImplementation(const LLImplementation&) = delete;
+    LLImplementation(LLImplementation&&);
+    ~LLImplementation();
+
+    static bool isImplementation(const Dlhandle&);
+
+    std::string_view modelType, buildVariant;
+    bool (*magicMatch)(std::ifstream& f);
+    Dlhandle *dlhandle;
+
+    // The only way an implementation should be constructed
+    LLModel *construct() const {
+        auto fres = construct_();
+        fres->m_implementation = this;
+        return fres;
+    }
+};
+
 #endif // LLMODEL_H
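
A brief usage sketch follows (not part of the patch) showing how client code drives the relocated class. The resolution flow it narrates -- LLModel::construct() scanning implementationList() for the first LLImplementation whose magicMatch() accepts the file, then LLImplementation::construct() back-linking the new model through m_implementation -- is taken from the diff above; the model path is a hypothetical placeholder and error handling is minimal.

// Usage sketch against the post-patch API; "./ggml-model.bin" is a
// hypothetical path used only for illustration.
#include <iostream>
#include <string>
#include <string_view>

#include "llmodel.h"

int main() {
    const std::string modelPath = "./ggml-model.bin";

    // construct() opens the file, walks implementationList(), and picks the
    // first LLImplementation whose magicMatch() recognizes the file's magic
    // bytes; LLImplementation::construct() then back-links the new LLModel
    // to its implementation via the protected m_implementation pointer
    // (LLImplementation is declared a friend of LLModel for this purpose).
    LLModel *model = LLModel::construct(modelPath, /*buildVariant=*/"auto");
    if (!model) {
        std::cerr << "no backend implementation matched " << modelPath << '\n';
        return 1;
    }

    // The back-link set during construction is what makes this query
    // possible on a loaded model.
    const LLImplementation &impl = model->implementation();
    std::cout << "model type:    " << impl.modelType << '\n'
              << "build variant: " << impl.buildVariant << '\n';

    delete model;
}

Note that callers are otherwise unaffected: only the scope and name of the class change (nested LLModel::Implementation becomes top-level LLImplementation), while the dynamic-library probing and construction logic is moved verbatim.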