Dlopen backend 5 (#779)

Major change to the backend that allows for pluggable versions of llama.cpp/ggml. It was squash-merged from the dlopen_backend_5 branch, where the full history is preserved.
This commit is contained in:
AT
2023-05-31 17:04:01 -04:00
committed by GitHub
parent f4a1f7340c
commit 48275d0dcc
22 changed files with 993 additions and 327 deletions

View File

@@ -1,79 +1,57 @@
#include "llmodel_c.h"
#include "llmodel.h"
#include <cstring>
#include <cerrno>
#include <utility>
#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"
// Pairs a polymorphic model instance with the prompt context that the C API
// threads through successive calls. Opaque llmodel_model handles returned by
// the create functions below are really LLModelWrapper* (see the casts).
struct LLModelWrapper {
LLModel *llModel = nullptr;  // owned; deleted by the matching *_destroy function
LLModel::PromptContext promptContext;
};
// Allocate a GPT-J model wrapped for the C API.
// The returned opaque handle must be released with llmodel_gptj_destroy.
llmodel_model llmodel_gptj_create()
{
    auto *handle = new LLModelWrapper;
    handle->llModel = new GPTJ;
    return reinterpret_cast<void*>(handle);
}
// Release a handle obtained from llmodel_gptj_create.
// Tolerates a null handle, like free(), so callers can pass the result of a
// failed create unconditionally.
void llmodel_gptj_destroy(llmodel_model gptj)
{
    if (!gptj) return;
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(gptj);
    delete wrapper->llModel;
    delete wrapper;
}
// Backing storage for llmodel_error::message pointers handed to C callers
// (see llmodel_model_create2). thread_local keeps concurrent callers from
// clobbering each other's message; the pointer stays valid until the next
// failure on the same thread.
thread_local static std::string last_error_message;
// Allocate an MPT model wrapped for the C API.
// The returned opaque handle must be released with llmodel_mpt_destroy.
llmodel_model llmodel_mpt_create()
{
    auto *handle = new LLModelWrapper;
    handle->llModel = new MPT;
    return reinterpret_cast<void*>(handle);
}
// Release a handle obtained from llmodel_mpt_create.
// Tolerates a null handle, like free(), so callers can pass the result of a
// failed create unconditionally.
void llmodel_mpt_destroy(llmodel_model mpt)
{
    if (!mpt) return;
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
    delete wrapper->llModel;
    delete wrapper;
}
// Allocate a LLaMA model wrapped for the C API.
// The returned opaque handle must be released with llmodel_llama_destroy.
llmodel_model llmodel_llama_create()
{
    auto *handle = new LLModelWrapper;
    handle->llModel = new LLamaModel;
    return reinterpret_cast<void*>(handle);
}
// Release a handle obtained from llmodel_llama_create.
// Tolerates a null handle, like free(), so callers can pass the result of a
// failed create unconditionally.
void llmodel_llama_destroy(llmodel_model llama)
{
    if (!llama) return;
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(llama);
    delete wrapper->llModel;
    delete wrapper;
}
// Legacy constructor: picks the build variant automatically and reports
// failure on stderr instead of through an llmodel_error out-parameter.
llmodel_model llmodel_model_create(const char *model_path) {
    llmodel_model model = llmodel_model_create2(model_path, "auto", nullptr);
    if (model == nullptr) {
        fprintf(stderr, "Invalid model file\n");
    }
    return model;
}
// Construct a model from a file, dispatching to the right implementation for
// the given build variant ("auto" lets LLModel::construct decide).
//
// On failure returns nullptr and, when `error` is non-null, fills it with an
// errno-style code and a message. The message pointer refers to thread-local
// storage and stays valid until the next failure on the same thread.
//
// NOTE(review): the pre-refactor magic-number sniffing (fopen/fread dispatch)
// was removed here; LLModel::construct now owns format detection.
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) {
    auto wrapper = new LLModelWrapper;
    llmodel_error new_error{};

    try {
        wrapper->llModel = LLModel::construct(model_path, build_variant);
    } catch (const std::exception& e) {
        new_error.code = EINVAL;
        last_error_message = e.what();
    }

    if (!wrapper->llModel) {
        delete std::exchange(wrapper, nullptr);
        // construct() failed without throwing: fall back to errno.
        if (new_error.code == 0) {
            new_error.code = errno;
            last_error_message = strerror(errno);
        }
        // Point the caller at the thread-local message buffer.
        new_error.message = last_error_message.c_str();
        if (error) *error = new_error;
    }
    // Cast to the opaque handle type itself (void*), not llmodel_model*.
    return reinterpret_cast<llmodel_model>(wrapper);
}
// Release a handle obtained from llmodel_model_create/create2.
// The old per-type typeid dispatch is gone: every handle is an LLModelWrapper
// owning a polymorphic LLModel, so a single delete path suffices (and keeping
// both paths would double-free llModel).
void llmodel_model_destroy(llmodel_model model) {
    if (!model) return;  // tolerate null, like free()
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    delete wrapper->llModel;
    delete wrapper;  // fix: the wrapper itself was previously leaked
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)
@@ -84,20 +62,20 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path)
// Return whether the wrapped model has finished loading its weights.
// (The stale pre-diff body that shadowed this with a second, unreachable
// return has been removed.)
bool llmodel_isModelLoaded(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->isModelLoaded();
}
// Return the number of bytes needed to serialize the model's state.
// (The stale pre-diff body that shadowed this with a second, unreachable
// return has been removed.)
uint64_t llmodel_get_state_size(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->stateSize();
}
// Serialize the model's state into `dest` and return the bytes written.
// `dest` must be at least llmodel_get_state_size(model) bytes.
// (The stale pre-diff body that shadowed this with a second, unreachable
// return has been removed.)
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->saveState(dest);
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
@@ -181,6 +159,6 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
// Return the number of threads the model is configured to use
// (see llmodel_setThreadCount).
// (The stale pre-diff body that shadowed this with a second, unreachable
// return has been removed.)
int32_t llmodel_threadCount(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->threadCount();
}