mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-06 02:50:36 +00:00
Dlopen backend 5 (#779)
Major change to the backend that allows for pluggable versions of llama.cpp/ggml. This was squashed merged from dlopen_backend_5 where the history is preserved.
This commit is contained in:
@@ -1,79 +1,57 @@
|
||||
#include "llmodel_c.h"
|
||||
#include "llmodel.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <cerrno>
|
||||
#include <utility>
|
||||
|
||||
#include "gptj.h"
|
||||
#include "llamamodel.h"
|
||||
#include "mpt.h"
|
||||
|
||||
struct LLModelWrapper {
|
||||
LLModel *llModel = nullptr;
|
||||
LLModel::PromptContext promptContext;
|
||||
};
|
||||
|
||||
llmodel_model llmodel_gptj_create()
|
||||
{
|
||||
LLModelWrapper *wrapper = new LLModelWrapper;
|
||||
wrapper->llModel = new GPTJ;
|
||||
return reinterpret_cast<void*>(wrapper);
|
||||
}
|
||||
|
||||
void llmodel_gptj_destroy(llmodel_model gptj)
|
||||
{
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(gptj);
|
||||
delete wrapper->llModel;
|
||||
delete wrapper;
|
||||
}
|
||||
thread_local static std::string last_error_message;
|
||||
|
||||
llmodel_model llmodel_mpt_create()
|
||||
{
|
||||
LLModelWrapper *wrapper = new LLModelWrapper;
|
||||
wrapper->llModel = new MPT;
|
||||
return reinterpret_cast<void*>(wrapper);
|
||||
}
|
||||
|
||||
void llmodel_mpt_destroy(llmodel_model mpt)
|
||||
{
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
|
||||
delete wrapper->llModel;
|
||||
delete wrapper;
|
||||
}
|
||||
|
||||
llmodel_model llmodel_llama_create()
|
||||
{
|
||||
LLModelWrapper *wrapper = new LLModelWrapper;
|
||||
wrapper->llModel = new LLamaModel;
|
||||
return reinterpret_cast<void*>(wrapper);
|
||||
}
|
||||
|
||||
void llmodel_llama_destroy(llmodel_model llama)
|
||||
{
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(llama);
|
||||
delete wrapper->llModel;
|
||||
delete wrapper;
|
||||
}
|
||||
|
||||
llmodel_model llmodel_model_create(const char *model_path) {
|
||||
auto fres = llmodel_model_create2(model_path, "auto", nullptr);
|
||||
if (!fres) {
|
||||
fprintf(stderr, "Invalid model file\n");
|
||||
}
|
||||
return fres;
|
||||
}
|
||||
|
||||
uint32_t magic;
|
||||
llmodel_model model;
|
||||
FILE *f = fopen(model_path, "rb");
|
||||
fread(&magic, sizeof(magic), 1, f);
|
||||
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) {
|
||||
auto wrapper = new LLModelWrapper;
|
||||
llmodel_error new_error{};
|
||||
|
||||
if (magic == 0x67676d6c) { model = llmodel_gptj_create(); }
|
||||
else if (magic == 0x67676a74) { model = llmodel_llama_create(); }
|
||||
else if (magic == 0x67676d6d) { model = llmodel_mpt_create(); }
|
||||
else {fprintf(stderr, "Invalid model file\n");}
|
||||
fclose(f);
|
||||
return model;
|
||||
try {
|
||||
wrapper->llModel = LLModel::construct(model_path, build_variant);
|
||||
} catch (const std::exception& e) {
|
||||
new_error.code = EINVAL;
|
||||
last_error_message = e.what();
|
||||
}
|
||||
|
||||
if (!wrapper->llModel) {
|
||||
delete std::exchange(wrapper, nullptr);
|
||||
// Get errno and error message if none
|
||||
if (new_error.code == 0) {
|
||||
new_error.code = errno;
|
||||
last_error_message = strerror(errno);
|
||||
}
|
||||
// Set message pointer
|
||||
new_error.message = last_error_message.c_str();
|
||||
// Set error argument
|
||||
if (error) *error = new_error;
|
||||
}
|
||||
return reinterpret_cast<llmodel_model*>(wrapper);
|
||||
}
|
||||
|
||||
void llmodel_model_destroy(llmodel_model model) {
|
||||
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
const std::type_info &modelTypeInfo = typeid(*wrapper->llModel);
|
||||
|
||||
if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); }
|
||||
if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); }
|
||||
if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); }
|
||||
delete wrapper->llModel;
|
||||
}
|
||||
|
||||
bool llmodel_loadModel(llmodel_model model, const char *model_path)
|
||||
@@ -84,20 +62,20 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path)
|
||||
|
||||
bool llmodel_isModelLoaded(llmodel_model model)
|
||||
{
|
||||
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
|
||||
return llm->isModelLoaded();
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
return wrapper->llModel->isModelLoaded();
|
||||
}
|
||||
|
||||
uint64_t llmodel_get_state_size(llmodel_model model)
|
||||
{
|
||||
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
|
||||
return llm->stateSize();
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
return wrapper->llModel->stateSize();
|
||||
}
|
||||
|
||||
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
|
||||
{
|
||||
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
|
||||
return llm->saveState(dest);
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
return wrapper->llModel->saveState(dest);
|
||||
}
|
||||
|
||||
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
|
||||
@@ -181,6 +159,6 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
|
||||
|
||||
int32_t llmodel_threadCount(llmodel_model model)
|
||||
{
|
||||
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
|
||||
return llm->threadCount();
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
return wrapper->llModel->threadCount();
|
||||
}
|
||||
|
Reference in New Issue
Block a user