Mirror of https://github.com/nomic-ai/gpt4all.git

Commit: Move the llmodel C API to new top-level directory and version it.
gpt4all-backend/CMakeLists.txt (new file, 55 lines)
@@ -0,0 +1,55 @@
cmake_minimum_required(VERSION 3.16)

if(APPLE)
  option(BUILD_UNIVERSAL "Build a Universal binary on macOS" ON)
  if(BUILD_UNIVERSAL)
    # Build a Universal binary on macOS
    # This requires that the found Qt library is compiled as Universal binaries.
    set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE)
  else()
    # Build for the host architecture on macOS
    set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
  endif()
endif()

# Include the binary directory for the generated header file
include_directories("${CMAKE_CURRENT_BINARY_DIR}")

set(LLMODEL_VERSION_MAJOR 0)
set(LLMODEL_VERSION_MINOR 1)
set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
set(BUILD_SHARED_LIBS ON FORCE)

set(CMAKE_VERBOSE_MAKEFILE ON)
if (GPT4ALL_AVX_ONLY)
    set(LLAMA_AVX2 OFF CACHE BOOL "llama: enable AVX2" FORCE)
    set(LLAMA_F16C OFF CACHE BOOL "llama: enable F16C" FORCE)
    set(LLAMA_FMA  OFF CACHE BOOL "llama: enable FMA" FORCE)
endif()

add_subdirectory(llama.cpp)

add_library(llmodel
    gptj.h gptj.cpp
    llamamodel.h llamamodel.cpp
    llama.cpp/examples/common.cpp
    llmodel.h llmodel_c.h llmodel_c.cpp
    mpt.h mpt.cpp
    utils.h utils.cpp
)

target_link_libraries(llmodel
    PRIVATE llama)

set_target_properties(llmodel PROPERTIES
                      VERSION ${PROJECT_VERSION}
                      SOVERSION ${PROJECT_VERSION_MAJOR})

set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)
gpt4all-backend/gptj.cpp (new file, 1102 lines)
File diff suppressed because it is too large.
gpt4all-backend/gptj.h (new file, 36 lines)
@@ -0,0 +1,36 @@
#ifndef GPTJ_H
#define GPTJ_H

#include <string>
#include <functional>
#include <vector>
#include "llmodel.h"

class GPTJPrivate;
class GPTJ : public LLModel {
public:
    GPTJ();
    ~GPTJ();

    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &ctx) override;
    void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() override;

protected:
    void recalculateContext(PromptContext &promptCtx,
        std::function<bool(bool)> recalculate) override;

private:
    GPTJPrivate *d_ptr;
};

#endif // GPTJ_H
gpt4all-backend/llama.cpp (submodule added at 03ceb39c1e)
gpt4all-backend/llamamodel.cpp (new file, 260 lines)
@@ -0,0 +1,260 @@
#include "llamamodel.h"

#include "llama.cpp/examples/common.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/ggml.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <iostream>
#include <unistd.h>
#include <random>
#include <thread>
#include <unordered_set>

struct LLamaPrivate {
    const std::string modelPath;
    bool modelLoaded;
    llama_context *ctx = nullptr;
    llama_context_params params;
    int64_t n_threads = 0;
};

LLamaModel::LLamaModel()
    : d_ptr(new LLamaPrivate) {

    d_ptr->modelLoaded = false;
}

bool LLamaModel::loadModel(const std::string &modelPath)
{
    // load the model
    d_ptr->params = llama_context_default_params();

    gpt_params params;
    d_ptr->params.n_ctx     = 2048;
    d_ptr->params.n_parts   = params.n_parts;
    d_ptr->params.seed      = params.seed;
    d_ptr->params.f16_kv    = params.memory_f16;
    d_ptr->params.use_mmap  = params.use_mmap;
    d_ptr->params.use_mlock = params.use_mlock;

    d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
    if (!d_ptr->ctx) {
        std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
        return false;
    }

    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
    d_ptr->modelLoaded = true;
    fflush(stderr);
    return true;
}

void LLamaModel::setThreadCount(int32_t n_threads) {
    d_ptr->n_threads = n_threads;
}

int32_t LLamaModel::threadCount() {
    return d_ptr->n_threads;
}

LLamaModel::~LLamaModel()
{
    llama_free(d_ptr->ctx);
}

bool LLamaModel::isModelLoaded() const
{
    return d_ptr->modelLoaded;
}

size_t LLamaModel::stateSize() const
{
    return llama_get_state_size(d_ptr->ctx);
}

size_t LLamaModel::saveState(uint8_t *dest) const
{
    return llama_copy_state_data(d_ptr->ctx, dest);
}

size_t LLamaModel::restoreState(const uint8_t *src)
{
    return llama_set_state_data(d_ptr->ctx, src);
}

void LLamaModel::prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &promptCtx) {

    if (!isModelLoaded()) {
        std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
        return;
    }

    gpt_params params;
    params.prompt = prompt;

    // Add a space in front of the first character to match OG llama tokenizer behavior
    params.prompt.insert(0, 1, ' ');

    // tokenize the prompt
    auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);

    // save the context size
    promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);

    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
        responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
        std::cerr << "LLAMA ERROR: The prompt is " << embd_inp.size() <<
            " tokens and the context window is " << promptCtx.n_ctx << "!\n";
        return;
    }

    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);

    // number of tokens to keep when resetting context
    params.n_keep = (int)embd_inp.size();

    // process the prompt in batches
    size_t i = 0;
    const int64_t t_start_prompt_us = ggml_time_us();
    while (i < embd_inp.size()) {
        size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
        std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

        // Check if the context has run out...
        if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
            // Erase the first percentage of context from the tokens...
            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
            promptCtx.n_past = promptCtx.tokens.size();
            recalculateContext(promptCtx, recalculateCallback);
            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
        }

        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
            return;
        }

        size_t tokens = batch_end - i;
        for (size_t t = 0; t < tokens; ++t) {
            if (promptCtx.tokens.size() == promptCtx.n_ctx)
                promptCtx.tokens.erase(promptCtx.tokens.begin());
            promptCtx.tokens.push_back(batch.at(t));
            if (!promptCallback(batch.at(t)))
                return;
        }
        promptCtx.n_past += batch.size();
        i = batch_end;
    }

    std::string cachedResponse;
    std::vector<llama_token> cachedTokens;
    std::unordered_set<std::string> reversePrompts
        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };

    // predict next tokens
    int32_t totalPredictions = 0;
    for (int i = 0; i < promptCtx.n_predict; i++) {
        // sample next token
        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
            promptCtx.repeat_penalty);

        // Check if the context has run out...
        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
            // Erase the first percentage of context from the tokens...
            std::cerr << "LLAMA: reached the end of the context window so resizing\n";
            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
            promptCtx.n_past = promptCtx.tokens.size();
            recalculateContext(promptCtx, recalculateCallback);
            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
        }

        if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
            std::cerr << "LLAMA ERROR: Failed to predict next token\n";
            return;
        }

        promptCtx.n_past += 1;
        // display text
        ++totalPredictions;
        if (id == llama_token_eos())
            return;

        const std::string str = llama_token_to_str(d_ptr->ctx, id);

        // Check if the provided str is part of our reverse prompts
        bool foundPartialReversePrompt = false;
        const std::string completed = cachedResponse + str;
        if (reversePrompts.find(completed) != reversePrompts.end()) {
            return;
        }

        // Check if it partially matches our reverse prompts and if so, cache
        for (auto s : reversePrompts) {
            if (s.compare(0, completed.size(), completed) == 0) {
                foundPartialReversePrompt = true;
                cachedResponse = completed;
                break;
            }
        }

        // Regardless the token gets added to our cache
        cachedTokens.push_back(id);

        // Continue if we have found a partial match
        if (foundPartialReversePrompt)
            continue;

        // Empty the cache
        for (auto t : cachedTokens) {
            if (promptCtx.tokens.size() == promptCtx.n_ctx)
                promptCtx.tokens.erase(promptCtx.tokens.begin());
            promptCtx.tokens.push_back(t);
            if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
                return;
        }
        cachedTokens.clear();
    }
}

void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
    size_t i = 0;
    promptCtx.n_past = 0;
    while (i < promptCtx.tokens.size()) {
        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);

        assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);

        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
            goto stop_generating;
        }
        promptCtx.n_past += batch.size();
        if (!recalculate(true))
            goto stop_generating;
        i = batch_end;
    }
    assert(promptCtx.n_past == promptCtx.tokens.size());

stop_generating:
    recalculate(false);
}
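Editor's note: the reverse-prompt handling above buffers output while the generated text could still turn into a stop sequence, emits nothing on a full match, and flushes the buffer once a match becomes impossible. The standalone sketch below is not part of the commit; it isolates that prefix-matching idea with plain strings (the sample pieces and stop strings are made up) so the behavior can be tested without a model.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
    const std::unordered_set<std::string> reversePrompts = { "### Human", "### Assistant" };
    // Pretend these are the strings of successive sampled tokens.
    const std::vector<std::string> pieces = { "Hello", " there", "!", "##", "# Hu", "man" };

    std::string cached;   // text withheld because it might still become a stop sequence
    for (const std::string &piece : pieces) {
        const std::string completed = cached + piece;

        // Full match: stop generating; the buffered text is never emitted.
        if (reversePrompts.count(completed)) {
            std::cout << "\n[stop sequence hit]\n";
            return 0;
        }

        // Partial match: keep buffering instead of emitting.
        bool partial = false;
        for (const std::string &rp : reversePrompts) {
            if (rp.compare(0, completed.size(), completed) == 0) {
                partial = true;
                cached = completed;
                break;
            }
        }
        if (partial)
            continue;

        // No stop sequence can start with this text any more: flush it.
        std::cout << completed;
        cached.clear();
    }
    std::cout << '\n';
    return 0;
}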
gpt4all-backend/llamamodel.h (new file, 36 lines)
@@ -0,0 +1,36 @@
#ifndef LLAMAMODEL_H
#define LLAMAMODEL_H

#include <string>
#include <functional>
#include <vector>
#include "llmodel.h"

class LLamaPrivate;
class LLamaModel : public LLModel {
public:
    LLamaModel();
    ~LLamaModel();

    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &ctx) override;
    void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() override;

protected:
    void recalculateContext(PromptContext &promptCtx,
        std::function<bool(bool)> recalculate) override;

private:
    LLamaPrivate *d_ptr;
};

#endif // LLAMAMODEL_H
gpt4all-backend/llmodel.h (new file, 47 lines)
@@ -0,0 +1,47 @@
#ifndef LLMODEL_H
#define LLMODEL_H

#include <string>
#include <functional>
#include <vector>
#include <cstdint>

class LLModel {
public:
    explicit LLModel() {}
    virtual ~LLModel() {}

    virtual bool loadModel(const std::string &modelPath) = 0;
    virtual bool isModelLoaded() const = 0;
    virtual size_t stateSize() const { return 0; }
    virtual size_t saveState(uint8_t *dest) const { return 0; }
    virtual size_t restoreState(const uint8_t *src) { return 0; }
    struct PromptContext {
        std::vector<float> logits;    // logits of current context
        std::vector<int32_t> tokens;  // current tokens in the context window
        int32_t n_past = 0;           // number of tokens in past conversation
        int32_t n_ctx = 0;            // number of tokens possible in context window
        int32_t n_predict = 200;
        int32_t top_k = 40;
        float   top_p = 0.9f;
        float   temp = 0.9f;
        int32_t n_batch = 9;
        float   repeat_penalty = 1.10f;
        int32_t repeat_last_n = 64;   // last n tokens to penalize
        float   contextErase = 0.75f; // percent of context to erase if we exceed the context window
    };
    virtual void prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &ctx) = 0;
    virtual void setThreadCount(int32_t n_threads) {}
    virtual int32_t threadCount() { return 1; }

protected:
    virtual void recalculateContext(PromptContext &promptCtx,
        std::function<bool(bool)> recalculate) = 0;
};

#endif // LLMODEL_H
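Editor's note: to make the abstract interface above concrete, here is a minimal sketch (not part of the commit) of what a hypothetical backend has to provide. The class name EchoModel and its behavior are illustrative only; it performs no real inference and simply echoes the prompt back, one character per fake token, through the response callback.

#include "llmodel.h"

class EchoModel : public LLModel {
public:
    bool loadModel(const std::string &modelPath) override {
        m_path = modelPath;            // a real backend would parse/mmap the weights here
        m_loaded = true;
        return true;
    }
    bool isModelLoaded() const override { return m_loaded; }
    void prompt(const std::string &prompt,
                std::function<bool(int32_t)> promptCallback,
                std::function<bool(int32_t, const std::string&)> responseCallback,
                std::function<bool(bool)> recalculateCallback,
                PromptContext &ctx) override {
        (void)recalculateCallback;
        // "Process" the prompt: record fake token ids and let the caller observe them.
        for (size_t i = 0; i < prompt.size(); ++i) {
            int32_t fakeToken = static_cast<int32_t>(i);
            ctx.tokens.push_back(fakeToken);
            if (!promptCallback(fakeToken))
                return;                // caller asked us to stop processing the prompt
        }
        // "Generate": echo the prompt back through the response callback.
        for (size_t i = 0; i < prompt.size() && (int32_t) i < ctx.n_predict; ++i) {
            if (!responseCallback(static_cast<int32_t>(i), std::string(1, prompt[i])))
                return;                // caller asked us to stop generating
        }
    }

protected:
    void recalculateContext(PromptContext &, std::function<bool(bool)> recalculate) override {
        recalculate(false);            // nothing to re-evaluate in this toy backend
    }

private:
    std::string m_path;
    bool m_loaded = false;
};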
gpt4all-backend/llmodel_c.cpp (new file, 161 lines)
@@ -0,0 +1,161 @@
#include "llmodel_c.h"

#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"

struct LLModelWrapper {
    LLModel *llModel = nullptr;
    LLModel::PromptContext promptContext;
};

llmodel_model llmodel_gptj_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new GPTJ;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_gptj_destroy(llmodel_model gptj)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(gptj);
    delete wrapper->llModel;
    delete wrapper;
}

llmodel_model llmodel_mpt_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new MPT;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_mpt_destroy(llmodel_model mpt)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
    delete wrapper->llModel;
    delete wrapper;
}

llmodel_model llmodel_llama_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new LLamaModel;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_llama_destroy(llmodel_model llama)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(llama);
    delete wrapper->llModel;
    delete wrapper;
}

bool llmodel_loadModel(llmodel_model model, const char *model_path)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->loadModel(model_path);
}

bool llmodel_isModelLoaded(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->saveState(dest);
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->restoreState(src);
}

// Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
    llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
    return callback(token_id);
}

bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
    llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
    return callback(token_id, response.c_str());
}

bool recalculate_wrapper(bool is_recalculating, void *user_data) {
    llmodel_recalculate_callback callback = reinterpret_cast<llmodel_recalculate_callback>(user_data);
    return callback(is_recalculating);
}

void llmodel_prompt(llmodel_model model, const char *prompt,
                    llmodel_response_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);

    // Create std::function wrappers that call the C function pointers
    std::function<bool(int32_t)> prompt_func =
        std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
    std::function<bool(int32_t, const std::string&)> response_func =
        std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
    std::function<bool(bool)> recalc_func =
        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));

    // Copy the C prompt context
    wrapper->promptContext.n_past = ctx->n_past;
    wrapper->promptContext.n_ctx = ctx->n_ctx;
    wrapper->promptContext.n_predict = ctx->n_predict;
    wrapper->promptContext.top_k = ctx->top_k;
    wrapper->promptContext.top_p = ctx->top_p;
    wrapper->promptContext.temp = ctx->temp;
    wrapper->promptContext.n_batch = ctx->n_batch;
    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
    wrapper->promptContext.contextErase = ctx->context_erase;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);

    // Update the C context by giving access to the wrapper's raw pointers to std::vector data,
    // which involves no copies
    ctx->logits = wrapper->promptContext.logits.data();
    ctx->logits_size = wrapper->promptContext.logits.size();
    ctx->tokens = wrapper->promptContext.tokens.data();
    ctx->tokens_size = wrapper->promptContext.tokens.size();

    // Update the rest of the C prompt context
    ctx->n_past = wrapper->promptContext.n_past;
    ctx->n_ctx = wrapper->promptContext.n_ctx;
    ctx->n_predict = wrapper->promptContext.n_predict;
    ctx->top_k = wrapper->promptContext.top_k;
    ctx->top_p = wrapper->promptContext.top_p;
    ctx->temp = wrapper->promptContext.temp;
    ctx->n_batch = wrapper->promptContext.n_batch;
    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
    ctx->context_erase = wrapper->promptContext.contextErase;
}

void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    wrapper->llModel->setThreadCount(n_threads);
}

int32_t llmodel_threadCount(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->threadCount();
}
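Editor's note: the std::bind wrappers above adapt plain C function pointers to the std::function parameters that LLModel::prompt expects. Capturing the pointer in a lambda achieves the same adaptation; the sketch below is not part of the commit and uses made-up names (c_token_callback, print_token, drive) to show the pattern in isolation.

#include <cstdio>
#include <functional>

// A C-style callback type, analogous to llmodel_prompt_callback.
typedef bool (*c_token_callback)(int token_id);

static bool print_token(int token_id) {
    printf("token %d\n", token_id);
    return true;
}

// Stands in for the C++ side, which only knows about std::function.
static void drive(std::function<bool(int)> cb) {
    for (int t = 0; t < 3 && cb(t); ++t) { }
}

int main() {
    c_token_callback raw = &print_token;
    // Wrap the raw C function pointer, as llmodel_prompt does with std::bind.
    std::function<bool(int)> wrapped = [raw](int token_id) { return raw(token_id); };
    drive(wrapped);
    return 0;
}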
gpt4all-backend/llmodel_c.h (new file, 172 lines)
@@ -0,0 +1,172 @@
#ifndef LLMODEL_C_H
#define LLMODEL_C_H

#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Opaque pointer to the underlying model.
 */
typedef void *llmodel_model;

/**
 * llmodel_prompt_context structure for holding the prompt context.
 * NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
 * raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
 * behavior.
 */
typedef struct {
    float *logits;          // logits of current context
    size_t logits_size;     // the size of the raw logits vector
    int32_t *tokens;        // current tokens in the context window
    size_t tokens_size;     // the size of the raw tokens vector
    int32_t n_past;         // number of tokens in past conversation
    int32_t n_ctx;          // number of tokens possible in context window
    int32_t n_predict;      // number of tokens to predict
    int32_t top_k;          // top k logits to sample from
    float top_p;            // nucleus sampling probability threshold
    float temp;             // temperature to adjust model's output distribution
    int32_t n_batch;        // number of predictions to generate in parallel
    float repeat_penalty;   // penalty factor for repeated tokens
    int32_t repeat_last_n;  // last n tokens to penalize
    float context_erase;    // percent of context to erase if we exceed the context window
} llmodel_prompt_context;

/**
 * Callback type for prompt processing.
 * @param token_id The token id of the prompt.
 * @return a bool indicating whether the model should keep processing.
 */
typedef bool (*llmodel_prompt_callback)(int32_t token_id);

/**
 * Callback type for response.
 * @param token_id The token id of the response.
 * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
 * @return a bool indicating whether the model should keep generating.
 */
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);

/**
 * Callback type for recalculation of context.
 * @param is_recalculating whether the model is recalculating the context.
 * @return a bool indicating whether the model should keep generating.
 */
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);

/**
 * Create a GPTJ instance.
 * @return A pointer to the GPTJ instance.
 */
llmodel_model llmodel_gptj_create();

/**
 * Destroy a GPTJ instance.
 * @param gptj A pointer to the GPTJ instance.
 */
void llmodel_gptj_destroy(llmodel_model gptj);

/**
 * Create an MPT instance.
 * @return A pointer to the MPT instance.
 */
llmodel_model llmodel_mpt_create();

/**
 * Destroy an MPT instance.
 * @param mpt A pointer to the MPT instance.
 */
void llmodel_mpt_destroy(llmodel_model mpt);

/**
 * Create a LLAMA instance.
 * @return A pointer to the LLAMA instance.
 */
llmodel_model llmodel_llama_create();

/**
 * Destroy a LLAMA instance.
 * @param llama A pointer to the LLAMA instance.
 */
void llmodel_llama_destroy(llmodel_model llama);

/**
 * Load a model from a file.
 * @param model A pointer to the llmodel_model instance.
 * @param model_path A string representing the path to the model file.
 * @return true if the model was loaded successfully, false otherwise.
 */
bool llmodel_loadModel(llmodel_model model, const char *model_path);

/**
 * Check if a model is loaded.
 * @param model A pointer to the llmodel_model instance.
 * @return true if the model is loaded, false otherwise.
 */
bool llmodel_isModelLoaded(llmodel_model model);

/**
 * Get the size of the internal state of the model.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @return the size in bytes of the internal state of the model
 */
uint64_t llmodel_get_state_size(llmodel_model model);

/**
 * Saves the internal state of the model to the specified destination address.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @param dest A pointer to the destination.
 * @return the number of bytes copied
 */
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest);

/**
 * Restores the internal state of the model using data from the specified address.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @param src A pointer to the source.
 * @return the number of bytes read
 */
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);

/**
 * Generate a response using the model.
 * @param model A pointer to the llmodel_model instance.
 * @param prompt A string representing the input prompt.
 * @param prompt_callback A callback function for handling the processing of the prompt.
 * @param response_callback A callback function for handling the generated response.
 * @param recalculate_callback A callback function for handling recalculation requests.
 * @param ctx A pointer to the llmodel_prompt_context structure.
 */
void llmodel_prompt(llmodel_model model, const char *prompt,
                    llmodel_response_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx);

/**
 * Set the number of threads to be used by the model.
 * @param model A pointer to the llmodel_model instance.
 * @param n_threads The number of threads to be used.
 */
void llmodel_setThreadCount(llmodel_model model, int32_t n_threads);

/**
 * Get the number of threads currently being used by the model.
 * @param model A pointer to the llmodel_model instance.
 * @return The number of threads currently being used.
 */
int32_t llmodel_threadCount(llmodel_model model);

#ifdef __cplusplus
}
#endif

#endif // LLMODEL_C_H
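Editor's note: to show how the C API above fits together, here is a minimal, hypothetical caller that is not part of the commit. The model path, the prompt text, and the callback bodies are placeholders, and error handling is kept to a minimum. Note that llmodel_prompt declares its prompt callback parameter with the llmodel_response_callback type, so both callbacks below use the two-argument shape; n_ctx, logits, and tokens are left zeroed because the implementation fills them in.

#include "llmodel_c.h"
#include <cstdio>

static bool on_prompt(int32_t token_id, const char *response) {
    (void)token_id; (void)response;
    return true;                        // keep feeding the prompt
}

static bool on_response(int32_t token_id, const char *response) {
    if (token_id == -1)                 // -1 signals an error string
        fprintf(stderr, "error: %s\n", response);
    else
        fputs(response, stdout);
    return true;                        // keep generating
}

static bool on_recalculate(bool is_recalculating) {
    (void)is_recalculating;
    return true;                        // allow context recalculation to proceed
}

int main() {
    llmodel_model model = llmodel_llama_create();
    if (!llmodel_loadModel(model, "ggml-model-q4_0.bin")) {   // placeholder path
        llmodel_llama_destroy(model);
        return 1;
    }
    llmodel_setThreadCount(model, 4);

    llmodel_prompt_context ctx = {};    // logits/tokens pointers are owned by the implementation
    ctx.n_predict = 200;
    ctx.top_k = 40;
    ctx.top_p = 0.9f;
    ctx.temp = 0.9f;
    ctx.n_batch = 8;
    ctx.repeat_penalty = 1.1f;
    ctx.repeat_last_n = 64;
    ctx.context_erase = 0.75f;

    llmodel_prompt(model, "### Human: What is a llama?\n### Assistant:",
                   on_prompt, on_response, on_recalculate, &ctx);

    llmodel_llama_destroy(model);
    return 0;
}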
gpt4all-backend/mpt.cpp (new file, 1240 lines)
File diff suppressed because it is too large.
gpt4all-backend/mpt.h (new file, 36 lines)
@@ -0,0 +1,36 @@
#ifndef MPT_H
#define MPT_H

#include <string>
#include <functional>
#include <vector>
#include "llmodel.h"

class MPTPrivate;
class MPT : public LLModel {
public:
    MPT();
    ~MPT();

    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &ctx) override;
    void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() override;

protected:
    void recalculateContext(PromptContext &promptCtx,
        std::function<bool(bool)> recalculate) override;

private:
    MPTPrivate *d_ptr;
};

#endif // MPT_H
gpt4all-backend/scripts/convert_mpt_hf_to_ggml.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# Convert Hugging Face fine-tuned bloom-like models to ggml format
#
# Usage:
#
#   python3 models/convert-h5-to-ggml.py
#
# This script is similar to "convert-pt-to-ggml.py"
#

import io
import os
import sys
import struct
import json
import code
import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BloomForCausalLM

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

if len(sys.argv) < 3:
    print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
    print("  model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'")
    print("  dir-output: directory where the output file will be written")
    print("  use-f32:    if present, use float32 instead of float16")
    sys.exit(1)

model_name = sys.argv[1]
dir_out = sys.argv[2]

# make sure the output directory exists
os.makedirs(dir_out, exist_ok=True)

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if len(sys.argv) > 3:
    ftype = 0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
hparams = config.to_dict()
print("Loading model: ", model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True)
print("Model loaded: ", model_name)

fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
fout = open(fname_out, "wb")
vocab = tokenizer.vocab

hparams["multiple_of"] = 1
fout.write(struct.pack("i", 0x67676d6d)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["max_seq_len"]))
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
# n_rot (unused)
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", ftype))

# # Is this correct??
# dot_token = tokenizer.encode(".")[0]
# write tokens to ggml file
fout.write(struct.pack("i", hparams["vocab_size"]))

for i in range(hparams["vocab_size"]):
    text = tokenizer.decode([i]).encode('utf-8')
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

list_vars = model.state_dict()
for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable: " + name + " with shape: ", data.shape)

    # we don't need these
    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
        print("  Skipping variable: " + name)
        continue

    if "Wqkv.weight" in name:
        # chunk qkv
        query, key, value = np.split(data, 3, axis=0)

        new_name = name.split("Wqkv.weight")[0]

        for (data, name) in [(query, new_name + "q_proj.weight"), (key, new_name + "k_proj.weight"), (value, new_name + "v_proj.weight")]:
            print(f"Processing variable: {name} with shape: {data.shape}")
            n_dims = len(data.shape)

            # ftype == 0 -> float32, ftype == 1 -> float16
            ftype_cur = 0
            if ftype != 0:
                print("  Converting to float16")
                data = data.astype(np.float16)
                ftype_cur = 1
            else:
                if data.dtype != np.float32:
                    print("  Converting to float32")
                    data = data.astype(np.float32)
                    ftype_cur = 0

            # header
            str = name.encode('utf-8')
            fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
            for i in range(n_dims):
                fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
            fout.write(str)

            # data
            data.tofile(fout)

    else:
        n_dims = len(data.shape)

        # ftype == 0 -> float32, ftype == 1 -> float16
        ftype_cur = 0
        if ftype != 0:
            if name[-7:] == ".weight" and n_dims == 2:
                print("  Converting to float16")
                data = data.astype(np.float16)
                ftype_cur = 1
            else:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0
        else:
            if data.dtype != np.float32:
                print("  Converting to float32")
                data = data.astype(np.float32)
                ftype_cur = 0

        # header
        str = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
        for i in range(n_dims):
            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
        fout.write(str)

        # data
        data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
gpt4all-backend/utils.cpp (new file, 274 lines)
@@ -0,0 +1,274 @@
#include "utils.h"

#include <fstream>
#include <regex>

void replace(std::string & str, const std::string & needle, const std::string & replacement) {
    size_t pos = 0;
    while ((pos = str.find(needle, pos)) != std::string::npos) {
        str.replace(pos, needle.length(), replacement);
        pos += replacement.length();
    }
}

std::map<std::string, int32_t> json_parse(const std::string & fname) {
    std::map<std::string, int32_t> result;

    // read file into string
    std::string json;
    {
        std::ifstream ifs(fname);
        if (!ifs) {
            fprintf(stderr, "Failed to open %s\n", fname.c_str());
            exit(1);
        }

        json = std::string((std::istreambuf_iterator<char>(ifs)),
                           (std::istreambuf_iterator<char>()));
    }

    if (json[0] != '{') {
        return result;
    }

    // parse json
    {
        bool has_key  = false;
        bool in_token = false;

        std::string str_key = "";
        std::string str_val = "";

        int n = json.size();
        for (int i = 1; i < n; ++i) {
            if (!in_token) {
                if (json[i] == ' ') continue;
                if (json[i] == '"') {
                    in_token = true;
                    continue;
                }
            } else {
                if (json[i] == '\\' && i+1 < n) {
                    if (has_key == false) {
                        str_key += json[i];
                    } else {
                        str_val += json[i];
                    }
                    ++i;
                } else if (json[i] == '"') {
                    if (has_key == false) {
                        has_key = true;
                        ++i;
                        while (json[i] == ' ') ++i;
                        ++i; // :
                        while (json[i] == ' ') ++i;
                        if (json[i] != '\"') {
                            while (json[i] != ',' && json[i] != '}') {
                                str_val += json[i++];
                            }
                            has_key = false;
                        } else {
                            in_token = true;
                            continue;
                        }
                    } else {
                        has_key = false;
                    }

                    ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
                    ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
                    ::replace(str_key, "\\\"", "\"");    // \\\"   -> "

                    try {
                        result[str_key] = std::stoi(str_val);
                    } catch (...) {
                        //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());

                    }
                    str_key = "";
                    str_val = "";
                    in_token = false;
                    continue;
                }
                if (has_key == false) {
                    str_key += json[i];
                } else {
                    str_val += json[i];
                }
            }
        }
    }

    return result;
}

std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        std::string str = text;
        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";

        std::regex re(pat);
        std::smatch m;

        while (std::regex_search(str, m, re)) {
            for (auto x : m) {
                words.push_back(x);
            }
            str = m.suffix();
        }
    }

    // find the longest tokens that form the words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
        if (word.size() == 0) continue;

        int i = 0;
        int n = word.size();
        while (i < n) {
            int j = n;
            while (j > i) {
                auto it = vocab.token_to_id.find(word.substr(i, j-i));
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    i = j;
                    break;
                }
                --j;
            }
            if (i == n) {
                break;
            }
            if (j == i) {
                auto sub = word.substr(i, 1);
                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
                    tokens.push_back(vocab.token_to_id.at(sub));
                } else {
                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
                }
                ++i;
            }
        }
    }

    return tokens;
}

bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
    printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());

    vocab.token_to_id = ::json_parse(fname);

    for (const auto & kv : vocab.token_to_id) {
        vocab.id_to_token[kv.second] = kv.first;
    }

    printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());

    // print the vocabulary
    //for (auto kv : vocab.token_to_id) {
    //    printf("'%s' -> %d\n", kv.first.data(), kv.second);
    //}

    return true;
}

gpt_vocab::id gpt_sample_top_k_top_p(
        const gpt_vocab & vocab,
        const int32_t * last_n_tokens_data,
        int   last_n_tokens_size,
        const std::vector<float> logits,
        int    top_k,
        double top_p,
        double temp,
        float repeat_penalty,
        std::mt19937 & rng) {
    int n_logits = vocab.id_to_token.size();

    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
    const auto * plogits = logits.data() + logits.size() - n_logits;

    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const float scale = 1.0f/temp;
        for (int i = 0; i < n_logits; ++i) {
            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if (plogits[i] < 0.0f) {
                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                } else {
                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                }
            } else {
                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
            }
        }
    }

    // find the top K tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);

    double maxl = -INFINITY;
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top K tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        double cumsum = 0.0f;
        for (int i = 0; i < top_k; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                top_k = i + 1;
                probs.resize(top_k);
                logits_id.resize(top_k);
                break;
            }
        }

        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    //printf("\n");
    //for (int i = 0; i < (int) probs.size(); i++) {
    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
    //}
    //exit(0);

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}
gpt4all-backend/utils.h (new file, 85 lines)
@@ -0,0 +1,85 @@
// Various helper functions and utilities

#pragma once

#include <string>
#include <map>
#include <vector>
#include <random>
#include <thread>

//
// CLI argument parsing
//

struct gpt_params {
    int32_t seed      = -1; // RNG seed
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
    int32_t n_predict = 200; // new tokens to predict

    // sampling parameters
    int32_t top_k = 40;
    float   top_p = 0.9f;
    float   temp  = 0.9f;

    int32_t n_batch = 8; // batch size for prompt processing

    std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
    std::string prompt;
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

void gpt_print_usage(int argc, char ** argv, const gpt_params & params);

std::string gpt_random_prompt(std::mt19937 & rng);

//
// Vocab utils
//

struct gpt_vocab {
    using id    = int32_t;
    using token = std::string;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;
};

void replace(std::string & str, const std::string & needle, const std::string & replacement);

// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);

// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);

// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);

// sample next token given probabilities for each embedding
//
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
        const gpt_vocab & vocab,
        const int32_t * last_n_tokens_data,
        int   last_n_tokens_size,
        const std::vector<float> logits,
        int    top_k,
        double top_p,
        double temp,
        float repeat_penalty,
        std::mt19937 & rng);
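Editor's note: a minimal sketch (not part of the commit) showing how the vocab helpers above fit together. The encoder.json path and the uniform dummy logits are placeholders; in the real backends the logits come from a ggml evaluation.

#include "utils.h"
#include <cstdio>

int main() {
    gpt_vocab vocab;
    if (!gpt_vocab_init("models/encoder.json", vocab))   // placeholder path
        return 1;

    // Split a prompt into vocab ids using greedy longest-match tokenization.
    std::vector<gpt_vocab::id> ids = gpt_tokenize(vocab, "Hello world");

    // Pretend every token is equally likely, then sample one id.
    std::vector<float> logits(vocab.id_to_token.size(), 0.0f);
    std::mt19937 rng(1234);
    gpt_vocab::id next = gpt_sample_top_k_top_p(vocab, ids.data(), (int) ids.size(),
                                                logits, /*top_k=*/40, /*top_p=*/0.9,
                                                /*temp=*/0.9, /*repeat_penalty=*/1.1f, rng);
    printf("sampled token id: %d\n", next);
    return 0;
}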