diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp
index d9300f04..b0e49808 100644
--- a/gpt4all-backend/llmodel.cpp
+++ b/gpt4all-backend/llmodel.cpp
@@ -41,7 +41,7 @@ static bool requires_avxonly() {
 #endif
 }

-LLMImplementation::LLMImplementation(Dlhandle &&dlhandle_)
+LLModel::Implementation::Implementation(Dlhandle &&dlhandle_)
     : m_dlhandle(new Dlhandle(std::move(dlhandle_))) {
     auto get_model_type = m_dlhandle->get<const char *()>("get_model_type");
     assert(get_model_type);
@@ -50,12 +50,12 @@ LLMImplementation::LLMImplementation(Dlhandle &&dlhandle_)
     assert(get_build_variant);
     m_buildVariant = get_build_variant();
     m_magicMatch = m_dlhandle->get<bool(std::ifstream&)>("magic_match");
-    assert(magicMatch);
+    assert(m_magicMatch);
     m_construct = m_dlhandle->get<LLModel *()>("construct");
-    assert(construct_);
+    assert(m_construct);
 }

-LLMImplementation::LLMImplementation(LLMImplementation &&o)
+LLModel::Implementation::Implementation(Implementation &&o)
     : m_magicMatch(o.m_magicMatch)
     , m_construct(o.m_construct)
     , m_modelType(o.m_modelType)
@@ -64,19 +64,19 @@ LLMImplementation::LLMImplementation(LLMImplementation &&o)
     o.m_dlhandle = nullptr;
 }

-LLMImplementation::~LLMImplementation() {
+LLModel::Implementation::~Implementation() {
     if (m_dlhandle) delete m_dlhandle;
 }

-bool LLMImplementation::isImplementation(const Dlhandle &dl) {
+bool LLModel::Implementation::isImplementation(const Dlhandle &dl) {
     return dl.get("is_g4a_backend_model_implementation");
 }

-const std::vector<LLMImplementation> &LLMImplementation::implementationList() {
+const std::vector<LLModel::Implementation> &LLModel::Implementation::implementationList() {
     // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
     // individual models without the cleanup of the static list interfering
-    static auto* libs = new std::vector<LLMImplementation>([] () {
-        std::vector<LLMImplementation> fres;
+    static auto* libs = new std::vector<Implementation>([] () {
+        std::vector<Implementation> fres;

         auto search_in_directory = [&](const std::string& paths) {
             std::stringstream ss(paths);
@@ -91,10 +91,10 @@ const std::vector<LLMImplementation> &LLMImplementation::implementationList() {
                     // Add to list if model implementation
                     try {
                         Dlhandle dl(p.string());
-                        if (!LLMImplementation::isImplementation(dl)) {
+                        if (!Implementation::isImplementation(dl)) {
                             continue;
                         }
-                        fres.emplace_back(LLMImplementation(std::move(dl)));
+                        fres.emplace_back(Implementation(std::move(dl)));
                     } catch (...) {}
                 }
             }
@@ -108,7 +108,7 @@ const std::vector<LLMImplementation> &LLMImplementation::implementationList() {
     return *libs;
 }

-const LLMImplementation* LLMImplementation::implementation(std::ifstream& f, const std::string& buildVariant) {
+const LLModel::Implementation* LLModel::Implementation::implementation(std::ifstream& f, const std::string& buildVariant) {
     for (const auto& i : implementationList()) {
         f.seekg(0);
         if (!i.m_magicMatch(f)) continue;
@@ -118,7 +118,7 @@ const LLMImplementation* LLMImplementation::implementation(std::ifstream& f, con
     return nullptr;
 }

-LLModel *LLMImplementation::construct(const std::string &modelPath, std::string buildVariant) {
+LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
     if (!has_at_least_minimal_hardware())
         return nullptr;

@@ -127,14 +127,15 @@ LLModel *LLMImplementation::construct(const std::string &modelPath, std::string
     std::ifstream f(modelPath, std::ios::binary);
     if (!f) return nullptr;
     // Get correct implementation
-    const LLMImplementation* impl = nullptr;
+    const Implementation* impl = nullptr;

 #if defined(__APPLE__) && defined(__arm64__) // FIXME: See if metal works for intel macs
     if (buildVariant == "auto") {
         size_t total_mem = getSystemTotalRAMInBytes();
         impl = implementation(f, "metal");
         if(impl) {
-            LLModel* metalimpl = impl->construct();
+            LLModel* metalimpl = impl->m_construct();
+            metalimpl->m_implementation = impl;
             size_t req_mem = metalimpl->requiredMem(modelPath);
             float req_to_total = (float) req_mem / (float) total_mem;
             // on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
@@ -168,10 +169,10 @@ LLModel *LLMImplementation::construct(const std::string &modelPath, std::string
     return fres;
 }

-void LLMImplementation::setImplementationsSearchPath(const std::string& path) {
+void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) {
     s_implementations_search_path = path;
 }

-const std::string& LLMImplementation::implementationsSearchPath() {
+const std::string& LLModel::Implementation::implementationsSearchPath() {
     return s_implementations_search_path;
 }
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index a5820174..06f9d618 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -12,10 +12,35 @@
 #define LLMODEL_MAX_PROMPT_BATCH 128

 class Dlhandle;
-class LLMImplementation;
 class LLModel {
 public:
     using Token = int32_t;
+    class Implementation {
+    public:
+        Implementation(Dlhandle&&);
+        Implementation(const Implementation&) = delete;
+        Implementation(Implementation&&);
+        ~Implementation();
+
+        std::string_view modelType() const { return m_modelType; }
+        std::string_view buildVariant() const { return m_buildVariant; }
+
+        static bool isImplementation(const Dlhandle&);
+        static const std::vector<Implementation>& implementationList();
+        static const Implementation *implementation(std::ifstream& f, const std::string& buildVariant);
+        static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
+        static void setImplementationsSearchPath(const std::string& path);
+        static const std::string& implementationsSearchPath();
+
+    private:
+        bool (*m_magicMatch)(std::ifstream& f);
+        LLModel *(*m_construct)();
+
+    private:
+        std::string_view m_modelType;
+        std::string_view m_buildVariant;
+        Dlhandle *m_dlhandle;
+    };

     struct PromptContext {
         std::vector<float> logits;          // logits of current context
@@ -51,7 +76,7 @@ public:
     virtual void setThreadCount(int32_t /*n_threads*/) {}
     virtual int32_t threadCount() const { return 1; }

-    const LLMImplementation& implementation() const {
+    const Implementation& implementation() const {
         return *m_implementation;
     }

@@ -69,37 +94,10 @@ protected:
     // shared by all base classes so it isn't virtual
     void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);

-    const LLMImplementation *m_implementation = nullptr;
+    const Implementation *m_implementation = nullptr;

 private:
     friend class LLMImplementation;
 };

-class LLMImplementation {
-public:
-    LLMImplementation(Dlhandle&&);
-    LLMImplementation(const LLMImplementation&) = delete;
-    LLMImplementation(LLMImplementation&&);
-    ~LLMImplementation();
-
-    std::string_view modelType() const { return m_modelType; }
-    std::string_view buildVariant() const { return m_buildVariant; }
-
-    static bool isImplementation(const Dlhandle&);
-    static const std::vector<LLMImplementation>& implementationList();
-    static const LLMImplementation *implementation(std::ifstream& f, const std::string& buildVariant);
-    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
-    static void setImplementationsSearchPath(const std::string& path);
-    static const std::string& implementationsSearchPath();
-
-private:
-    bool (*m_magicMatch)(std::ifstream& f);
-    LLModel *(*m_construct)();
-
-private:
-    std::string_view m_modelType;
-    std::string_view m_buildVariant;
-    Dlhandle *m_dlhandle;
-};
-
 #endif // LLMODEL_H
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 2364e4fa..c7e13f79 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -29,7 +29,7 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *build_va
     int error_code = 0;

     try {
-        wrapper->llModel = LLMImplementation::construct(model_path, build_variant);
+        wrapper->llModel = LLModel::Implementation::construct(model_path, build_variant);
     } catch (const std::exception& e) {
         error_code = EINVAL;
         last_error_message = e.what();
@@ -180,10 +180,10 @@ int32_t llmodel_threadCount(llmodel_model model)

 void llmodel_set_implementation_search_path(const char *path)
 {
-    LLMImplementation::setImplementationsSearchPath(path);
+    LLModel::Implementation::setImplementationsSearchPath(path);
 }

 const char *llmodel_get_implementation_search_path()
 {
-    return LLMImplementation::implementationsSearchPath().c_str();
+    return LLModel::Implementation::implementationsSearchPath().c_str();
 }
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index 881ea5ec..fe1db763 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -45,8 +45,8 @@ void LLModel::prompt(const std::string &prompt,

     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
-        std::cerr << implementation().modelType() << " ERROR: The prompt is" << embd_inp.size() <<
-            "tokens and the context window is" << promptCtx.n_ctx << "!\n";
+        std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
+            " tokens and the context window is " << promptCtx.n_ctx << "!\n";
         return;
     }

diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index fa11cdbb..37c92d53 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -244,7 +244,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         else
             m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
 #else
-        m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
+        m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
 #endif

         if (m_llModelInfo.model) {
diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp
index ff62d43e..7953b296 100644
--- a/gpt4all-chat/llm.cpp
+++ b/gpt4all-chat/llm.cpp
@@ -34,7 +34,7 @@ LLM::LLM()
     if (directoryExists(frameworksDir))
         llmodelSearchPaths += ";" + frameworksDir;
 #endif
-    LLMImplementation::setImplementationsSearchPath(llmodelSearchPaths.toStdString());
+    LLModel::Implementation::setImplementationsSearchPath(llmodelSearchPaths.toStdString());

 #if defined(__x86_64__)
 #ifndef _MSC_VER