diff --git a/llmodel/llmodel_c.cpp b/llmodel/llmodel_c.cpp index cc4ecc70..ec430dcb 100644 --- a/llmodel/llmodel_c.cpp +++ b/llmodel/llmodel_c.cpp @@ -1,2 +1,120 @@ #include "llmodel_c.h" +#include "gptj.h" +#include "llamamodel.h" + +struct LLModelWrapper { + LLModel *llModel = nullptr; + LLModel::PromptContext promptContext; +}; + +llmodel_model llmodel_gptj_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new GPTJ; + return reinterpret_cast(wrapper); +} + +void llmodel_gptj_destroy(llmodel_model gptj) +{ + LLModelWrapper *wrapper = reinterpret_cast(gptj); + delete wrapper->llModel; + delete wrapper; +} + +llmodel_model llmodel_llama_create() +{ + LLModelWrapper *wrapper = new LLModelWrapper; + wrapper->llModel = new LLamaModel; + return reinterpret_cast(wrapper); +} + +void llmodel_llama_destroy(llmodel_model llama) +{ + LLModelWrapper *wrapper = reinterpret_cast(llama); + delete wrapper->llModel; + delete wrapper; +} + +bool llmodel_loadModel(llmodel_model model, const char *model_path) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->loadModel(model_path); +} + +bool llmodel_isModelLoaded(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->isModelLoaded(); +} + +// Wrapper functions for the C callbacks +bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) { + llmodel_response_callback callback = reinterpret_cast(user_data); + return callback(token_id, response.c_str()); +} + +bool recalculate_wrapper(bool is_recalculating, void *user_data) { + llmodel_recalculate_callback callback = reinterpret_cast(user_data); + return callback(is_recalculating); +} + +void llmodel_prompt(llmodel_model model, const char *prompt, + llmodel_response_callback response, + llmodel_recalculate_callback recalculate, + llmodel_prompt_context *ctx) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + + // Create std::function wrappers that call the C function pointers + std::function response_func = + std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast(response)); + std::function recalc_func = + std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast(recalculate)); + + // Copy the C prompt context + wrapper->promptContext.n_past = ctx->n_past; + wrapper->promptContext.n_ctx = ctx->n_ctx; + wrapper->promptContext.n_predict = ctx->n_predict; + wrapper->promptContext.top_k = ctx->top_k; + wrapper->promptContext.top_p = ctx->top_p; + wrapper->promptContext.temp = ctx->temp; + wrapper->promptContext.n_batch = ctx->n_batch; + wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; + wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; + wrapper->promptContext.contextErase = ctx->context_erase; + + // Call the C++ prompt method + wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext); + + // Update the C context by giving access to the wrappers raw pointers to std::vector data + // which involves no copies + ctx->logits = wrapper->promptContext.logits.data(); + ctx->logits_size = wrapper->promptContext.logits.size(); + ctx->tokens = wrapper->promptContext.tokens.data(); + ctx->tokens_size = wrapper->promptContext.tokens.size(); + + // Update the rest of the C prompt context + ctx->n_past = wrapper->promptContext.n_past; + ctx->n_ctx = wrapper->promptContext.n_ctx; + ctx->n_predict = wrapper->promptContext.n_predict; + ctx->top_k = wrapper->promptContext.top_k; + ctx->top_p = wrapper->promptContext.top_p; + ctx->temp = wrapper->promptContext.temp; + ctx->n_batch = wrapper->promptContext.n_batch; + ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; + ctx->repeat_last_n = wrapper->promptContext.repeat_last_n; + ctx->context_erase = wrapper->promptContext.contextErase; +} + +void llmodel_setThreadCount(llmodel_model model, int32_t n_threads) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + wrapper->llModel->setThreadCount(n_threads); +} + +int32_t llmodel_threadCount(llmodel_model model) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + return wrapper->llModel->threadCount(); +} diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h index 435ec0fa..b0b3fa95 100644 --- a/llmodel/llmodel_c.h +++ b/llmodel/llmodel_c.h @@ -2,6 +2,7 @@ #define LLMODEL_C_H #include +#include #include #ifdef __cplusplus @@ -15,10 +16,15 @@ typedef void *llmodel_model; /** * llmodel_prompt_context structure for holding the prompt context. + * NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the + * raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined + * behavior. */ typedef struct { float *logits; // logits of current context + size_t logits_size; // the size of the raw logits vector int32_t *tokens; // current tokens in the context window + size_t tokens_size; // the size of the raw tokens vector int32_t n_past; // number of tokens in past conversation int32_t n_ctx; // number of tokens possible in context window int32_t n_predict; // number of tokens to predict