mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-06 12:06:54 +00:00
Add llama.cpp support for loading llama based models in the gui. We now
support loading both gptj derived models and llama derived models.
This commit is contained in:
parent
00cb5fe2a5
commit
71b308e914
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -1,3 +1,6 @@
|
|||||||
[submodule "ggml"]
|
[submodule "ggml"]
|
||||||
path = ggml
|
path = ggml
|
||||||
url = https://github.com/manyoso/ggml.git
|
url = https://github.com/manyoso/ggml.git
|
||||||
|
[submodule "llama.cpp"]
|
||||||
|
path = llama.cpp
|
||||||
|
url = https://github.com/manyoso/llama.cpp.git
|
||||||
|
@ -28,15 +28,19 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|||||||
|
|
||||||
find_package(Qt6 6.2 COMPONENTS Quick Svg REQUIRED)
|
find_package(Qt6 6.2 COMPONENTS Quick Svg REQUIRED)
|
||||||
|
|
||||||
set(GGML_BUILD_EXAMPLES ON CACHE BOOL "ggml: build examples" FORCE)
|
set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
|
||||||
add_subdirectory(ggml)
|
set(BUILD_SHARED_LIBS ON FORCE)
|
||||||
|
add_subdirectory(llama.cpp)
|
||||||
|
|
||||||
qt_add_executable(chat
|
qt_add_executable(chat
|
||||||
main.cpp
|
main.cpp
|
||||||
download.h download.cpp
|
download.h download.cpp
|
||||||
gptj.h gptj.cpp
|
gptj.h gptj.cpp
|
||||||
|
llamamodel.h llamamodel.cpp
|
||||||
|
llama.cpp/examples/common.cpp
|
||||||
llm.h llm.cpp
|
llm.h llm.cpp
|
||||||
llmodel.h
|
llmodel.h
|
||||||
|
utils.h utils.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
qt_add_qml_module(chat
|
qt_add_qml_module(chat
|
||||||
@ -72,7 +76,7 @@ target_compile_definitions(chat
|
|||||||
target_link_libraries(chat
|
target_link_libraries(chat
|
||||||
PRIVATE Qt6::Quick Qt6::Svg)
|
PRIVATE Qt6::Quick Qt6::Svg)
|
||||||
target_link_libraries(chat
|
target_link_libraries(chat
|
||||||
PRIVATE ggml ggml_utils)
|
PRIVATE llama)
|
||||||
|
|
||||||
set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
|
set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
|
||||||
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)
|
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)
|
||||||
|
8
gptj.cpp
8
gptj.cpp
@ -1,5 +1,5 @@
|
|||||||
#include "gptj.h"
|
#include "gptj.h"
|
||||||
#include "ggml/ggml.h"
|
#include "llama.cpp/ggml.h"
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
@ -644,6 +644,12 @@ GPTJ::GPTJ()
|
|||||||
d_ptr->modelLoaded = false;
|
d_ptr->modelLoaded = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GPTJ::loadModel(const std::string &modelPath)
|
||||||
|
{
|
||||||
|
std::cerr << "GPTJ ERROR: loading gpt model from file unsupported!\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
||||||
std::mt19937 rng(time(NULL));
|
std::mt19937 rng(time(NULL));
|
||||||
d_ptr->rng = rng;
|
d_ptr->rng = rng;
|
||||||
|
1
gptj.h
1
gptj.h
@ -12,6 +12,7 @@ public:
|
|||||||
GPTJ();
|
GPTJ();
|
||||||
~GPTJ();
|
~GPTJ();
|
||||||
|
|
||||||
|
bool loadModel(const std::string &modelPath) override;
|
||||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||||
bool isModelLoaded() const override;
|
bool isModelLoaded() const override;
|
||||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
|
1
llama.cpp
Submodule
1
llama.cpp
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit c8c2c524827be8fd681a63f0e5a697b0bf4c587b
|
160
llamamodel.cpp
Normal file
160
llamamodel.cpp
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
#include "llamamodel.h"
|
||||||
|
|
||||||
|
#include "llama.cpp/examples/common.h"
|
||||||
|
#include "llama.cpp/llama.h"
|
||||||
|
#include "llama.cpp/ggml.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <iostream>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <random>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
struct LLamaPrivate {
|
||||||
|
const std::string modelPath;
|
||||||
|
bool modelLoaded;
|
||||||
|
llama_context *ctx = nullptr;
|
||||||
|
llama_context_params params;
|
||||||
|
int64_t n_threads = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
LLamaModel::LLamaModel()
|
||||||
|
: d_ptr(new LLamaPrivate) {
|
||||||
|
|
||||||
|
d_ptr->modelLoaded = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LLamaModel::loadModel(const std::string &modelPath, std::istream &fin)
|
||||||
|
{
|
||||||
|
std::cerr << "LLAMA ERROR: loading llama model from stream unsupported!\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LLamaModel::loadModel(const std::string &modelPath)
|
||||||
|
{
|
||||||
|
// load the model
|
||||||
|
d_ptr->params = llama_context_default_params();
|
||||||
|
d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
|
||||||
|
if (!d_ptr->ctx) {
|
||||||
|
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
|
d_ptr->modelLoaded = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LLamaModel::setThreadCount(int32_t n_threads) {
|
||||||
|
d_ptr->n_threads = n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t LLamaModel::threadCount() {
|
||||||
|
return d_ptr->n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
LLamaModel::~LLamaModel()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LLamaModel::isModelLoaded() const
|
||||||
|
{
|
||||||
|
return d_ptr->modelLoaded;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
|
PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
|
||||||
|
|
||||||
|
if (!isModelLoaded()) {
|
||||||
|
std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpt_params params;
|
||||||
|
params.prompt = prompt;
|
||||||
|
|
||||||
|
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||||
|
params.prompt.insert(0, 1, ' ');
|
||||||
|
|
||||||
|
// tokenize the prompt
|
||||||
|
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
|
||||||
|
const int n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||||
|
|
||||||
|
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||||
|
std::cerr << "LLAMA ERROR: prompt is too long\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
|
||||||
|
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
|
||||||
|
|
||||||
|
// number of tokens to keep when resetting context
|
||||||
|
params.n_keep = (int)embd_inp.size();
|
||||||
|
|
||||||
|
// process the prompt in batches
|
||||||
|
size_t i = 0;
|
||||||
|
const int64_t t_start_prompt_us = ggml_time_us();
|
||||||
|
while (i < embd_inp.size()) {
|
||||||
|
size_t batch_end = std::min(i + n_batch, embd_inp.size());
|
||||||
|
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||||
|
|
||||||
|
if (promptCtx.n_past + batch.size() > n_ctx) {
|
||||||
|
std::cerr << "eval n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||||
|
promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
|
||||||
|
std::cerr << "after n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
|
||||||
|
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// We pass a null string for each token to see if the user has asked us to stop...
|
||||||
|
size_t tokens = batch_end - i;
|
||||||
|
for (size_t t = 0; t < tokens; ++t)
|
||||||
|
if (!response(""))
|
||||||
|
return;
|
||||||
|
promptCtx.n_past += batch.size();
|
||||||
|
i = batch_end;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> cachedTokens;
|
||||||
|
|
||||||
|
// predict next tokens
|
||||||
|
int32_t totalPredictions = 0;
|
||||||
|
for (int i = 0; i < n_predict; i++) {
|
||||||
|
// sample next token
|
||||||
|
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
|
||||||
|
|
||||||
|
if (promptCtx.n_past + 1 > n_ctx) {
|
||||||
|
std::cerr << "eval 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||||
|
promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
|
||||||
|
std::cerr << "after 2 n_ctx " << n_ctx << " n_past " << promptCtx.n_past << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
|
||||||
|
std::cerr << "LLAMA ERROR: Failed to predict next token\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cachedTokens.emplace_back(id);
|
||||||
|
|
||||||
|
for (int j = 0; j < cachedTokens.size(); ++j) {
|
||||||
|
llama_token cachedToken = cachedTokens.at(j);
|
||||||
|
promptCtx.n_past += 1;
|
||||||
|
// display text
|
||||||
|
++totalPredictions;
|
||||||
|
if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, cachedToken)))
|
||||||
|
goto stop_generating;
|
||||||
|
}
|
||||||
|
cachedTokens.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
stop_generating:
|
||||||
|
return;
|
||||||
|
}
|
28
llamamodel.h
Normal file
28
llamamodel.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#ifndef LLAMAMODEL_H
|
||||||
|
#define LLAMAMODEL_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
#include "llmodel.h"
|
||||||
|
|
||||||
|
class LLamaPrivate;
|
||||||
|
class LLamaModel : public LLModel {
|
||||||
|
public:
|
||||||
|
LLamaModel();
|
||||||
|
~LLamaModel();
|
||||||
|
|
||||||
|
bool loadModel(const std::string &modelPath) override;
|
||||||
|
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||||
|
bool isModelLoaded() const override;
|
||||||
|
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||||
|
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||||
|
void setThreadCount(int32_t n_threads) override;
|
||||||
|
int32_t threadCount() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
LLamaPrivate *d_ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // LLAMAMODEL_H
|
16
llm.cpp
16
llm.cpp
@ -47,20 +47,32 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
|||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (isModelLoaded()) {
|
if (isModelLoaded()) {
|
||||||
|
resetContext();
|
||||||
delete m_llmodel;
|
delete m_llmodel;
|
||||||
m_llmodel = nullptr;
|
m_llmodel = nullptr;
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
m_llmodel = new GPTJ;
|
bool isGPTJ = false;
|
||||||
|
|
||||||
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() +
|
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() +
|
||||||
"ggml-" + modelName + ".bin";
|
"ggml-" + modelName + ".bin";
|
||||||
QFileInfo info(filePath);
|
QFileInfo info(filePath);
|
||||||
if (info.exists()) {
|
if (info.exists()) {
|
||||||
|
|
||||||
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
||||||
|
|
||||||
|
uint32_t magic;
|
||||||
|
fin.read((char *) &magic, sizeof(magic));
|
||||||
|
fin.seekg(0);
|
||||||
|
isGPTJ = magic == 0x67676d6c;
|
||||||
|
if (isGPTJ) {
|
||||||
|
m_llmodel = new GPTJ;
|
||||||
m_llmodel->loadModel(modelName.toStdString(), fin);
|
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||||
|
} else {
|
||||||
|
m_llmodel = new LLamaModel;
|
||||||
|
m_llmodel->loadModel(filePath.toStdString());
|
||||||
|
}
|
||||||
|
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
emit threadCountChanged();
|
emit threadCountChanged();
|
||||||
}
|
}
|
||||||
|
1
llm.h
1
llm.h
@ -4,6 +4,7 @@
|
|||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QThread>
|
#include <QThread>
|
||||||
#include "gptj.h"
|
#include "gptj.h"
|
||||||
|
#include "llamamodel.h"
|
||||||
|
|
||||||
class LLMObject : public QObject
|
class LLMObject : public QObject
|
||||||
{
|
{
|
||||||
|
@ -10,6 +10,7 @@ public:
|
|||||||
explicit LLModel() {}
|
explicit LLModel() {}
|
||||||
virtual ~LLModel() {}
|
virtual ~LLModel() {}
|
||||||
|
|
||||||
|
virtual bool loadModel(const std::string &modelPath) = 0;
|
||||||
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
|
virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
|
||||||
virtual bool isModelLoaded() const = 0;
|
virtual bool isModelLoaded() const = 0;
|
||||||
struct PromptContext {
|
struct PromptContext {
|
||||||
@ -19,8 +20,8 @@ public:
|
|||||||
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||||
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||||
virtual void setThreadCount(int32_t n_threads);
|
virtual void setThreadCount(int32_t n_threads) {}
|
||||||
virtual int32_t threadCount();
|
virtual int32_t threadCount() { return 1; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // LLMODEL_H
|
#endif // LLMODEL_H
|
||||||
|
8
main.qml
8
main.qml
@ -70,7 +70,9 @@ Window {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onActivated: {
|
onActivated: {
|
||||||
|
LLM.stopGenerating()
|
||||||
LLM.modelName = comboBox.currentText
|
LLM.modelName = comboBox.currentText
|
||||||
|
chatModel.clear()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -775,7 +777,7 @@ Window {
|
|||||||
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
|
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
|
||||||
|
|
||||||
delegate: TextArea {
|
delegate: TextArea {
|
||||||
text: currentResponse ? LLM.response : value
|
text: currentResponse ? LLM.response : (value ? value : "")
|
||||||
width: listView.width
|
width: listView.width
|
||||||
color: "#d1d5db"
|
color: "#d1d5db"
|
||||||
wrapMode: Text.WordWrap
|
wrapMode: Text.WordWrap
|
||||||
@ -800,8 +802,8 @@ Window {
|
|||||||
anchors.leftMargin: 90
|
anchors.leftMargin: 90
|
||||||
anchors.top: parent.top
|
anchors.top: parent.top
|
||||||
anchors.topMargin: 5
|
anchors.topMargin: 5
|
||||||
visible: currentResponse && LLM.response === "" && LLM.responseInProgress
|
visible: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
|
||||||
running: currentResponse && LLM.response === "" && LLM.responseInProgress
|
running: (currentResponse ? true : false) && LLM.response === "" && LLM.responseInProgress
|
||||||
|
|
||||||
Accessible.role: Accessible.Animation
|
Accessible.role: Accessible.Animation
|
||||||
Accessible.name: qsTr("Busy indicator")
|
Accessible.name: qsTr("Busy indicator")
|
||||||
|
257
utils.cpp
Normal file
257
utils.cpp
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
void replace(std::string & str, const std::string & needle, const std::string & replacement) {
|
||||||
|
size_t pos = 0;
|
||||||
|
while ((pos = str.find(needle, pos)) != std::string::npos) {
|
||||||
|
str.replace(pos, needle.length(), replacement);
|
||||||
|
pos += replacement.length();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||||
|
std::map<std::string, int32_t> result;
|
||||||
|
|
||||||
|
// read file into string
|
||||||
|
std::string json;
|
||||||
|
{
|
||||||
|
std::ifstream ifs(fname);
|
||||||
|
if (!ifs) {
|
||||||
|
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||||
|
(std::istreambuf_iterator<char>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (json[0] != '{') {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse json
|
||||||
|
{
|
||||||
|
bool has_key = false;
|
||||||
|
bool in_token = false;
|
||||||
|
|
||||||
|
std::string str_key = "";
|
||||||
|
std::string str_val = "";
|
||||||
|
|
||||||
|
int n = json.size();
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
if (!in_token) {
|
||||||
|
if (json[i] == ' ') continue;
|
||||||
|
if (json[i] == '"') {
|
||||||
|
in_token = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (json[i] == '\\' && i+1 < n) {
|
||||||
|
if (has_key == false) {
|
||||||
|
str_key += json[i];
|
||||||
|
} else {
|
||||||
|
str_val += json[i];
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
} else if (json[i] == '"') {
|
||||||
|
if (has_key == false) {
|
||||||
|
has_key = true;
|
||||||
|
++i;
|
||||||
|
while (json[i] == ' ') ++i;
|
||||||
|
++i; // :
|
||||||
|
while (json[i] == ' ') ++i;
|
||||||
|
if (json[i] != '\"') {
|
||||||
|
while (json[i] != ',' && json[i] != '}') {
|
||||||
|
str_val += json[i++];
|
||||||
|
}
|
||||||
|
has_key = false;
|
||||||
|
} else {
|
||||||
|
in_token = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
has_key = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||||
|
::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||||
|
::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||||
|
|
||||||
|
try {
|
||||||
|
result[str_key] = std::stoi(str_val);
|
||||||
|
} catch (...) {
|
||||||
|
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||||
|
|
||||||
|
}
|
||||||
|
str_key = "";
|
||||||
|
str_val = "";
|
||||||
|
in_token = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (has_key == false) {
|
||||||
|
str_key += json[i];
|
||||||
|
} else {
|
||||||
|
str_val += json[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||||
|
std::vector<std::string> words;
|
||||||
|
|
||||||
|
// first split the text into words
|
||||||
|
{
|
||||||
|
std::string str = text;
|
||||||
|
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||||
|
|
||||||
|
std::regex re(pat);
|
||||||
|
std::smatch m;
|
||||||
|
|
||||||
|
while (std::regex_search(str, m, re)) {
|
||||||
|
for (auto x : m) {
|
||||||
|
words.push_back(x);
|
||||||
|
}
|
||||||
|
str = m.suffix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the longest tokens that form the words:
|
||||||
|
std::vector<gpt_vocab::id> tokens;
|
||||||
|
for (const auto & word : words) {
|
||||||
|
if (word.size() == 0) continue;
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
int n = word.size();
|
||||||
|
while (i < n) {
|
||||||
|
int j = n;
|
||||||
|
while (j > i) {
|
||||||
|
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||||
|
if (it != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(it->second);
|
||||||
|
i = j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
--j;
|
||||||
|
}
|
||||||
|
if (i == n) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (j == i) {
|
||||||
|
auto sub = word.substr(i, 1);
|
||||||
|
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(vocab.token_to_id.at(sub));
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||||
|
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
vocab.token_to_id = ::json_parse(fname);
|
||||||
|
|
||||||
|
for (const auto & kv : vocab.token_to_id) {
|
||||||
|
vocab.id_to_token[kv.second] = kv.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||||
|
|
||||||
|
// print the vocabulary
|
||||||
|
//for (auto kv : vocab.token_to_id) {
|
||||||
|
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||||
|
//}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||||
|
const gpt_vocab & vocab,
|
||||||
|
const float * logits,
|
||||||
|
int top_k,
|
||||||
|
double top_p,
|
||||||
|
double temp,
|
||||||
|
std::mt19937 & rng) {
|
||||||
|
int n_logits = vocab.id_to_token.size();
|
||||||
|
|
||||||
|
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||||
|
logits_id.reserve(n_logits);
|
||||||
|
|
||||||
|
{
|
||||||
|
const double scale = 1.0/temp;
|
||||||
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
|
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the top K tokens
|
||||||
|
std::partial_sort(
|
||||||
|
logits_id.begin(),
|
||||||
|
logits_id.begin() + top_k, logits_id.end(),
|
||||||
|
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||||
|
return a.first > b.first;
|
||||||
|
});
|
||||||
|
|
||||||
|
logits_id.resize(top_k);
|
||||||
|
|
||||||
|
double maxl = -INFINITY;
|
||||||
|
for (const auto & kv : logits_id) {
|
||||||
|
maxl = std::max(maxl, kv.first);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute probs for the top K tokens
|
||||||
|
std::vector<double> probs;
|
||||||
|
probs.reserve(logits_id.size());
|
||||||
|
|
||||||
|
double sum = 0.0;
|
||||||
|
for (const auto & kv : logits_id) {
|
||||||
|
double p = exp(kv.first - maxl);
|
||||||
|
probs.push_back(p);
|
||||||
|
sum += p;
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize the probs
|
||||||
|
for (auto & p : probs) {
|
||||||
|
p /= sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (top_p < 1.0f) {
|
||||||
|
double cumsum = 0.0f;
|
||||||
|
for (int i = 0; i < top_k; i++) {
|
||||||
|
cumsum += probs[i];
|
||||||
|
if (cumsum >= top_p) {
|
||||||
|
top_k = i + 1;
|
||||||
|
probs.resize(top_k);
|
||||||
|
logits_id.resize(top_k);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cumsum = 1.0/cumsum;
|
||||||
|
for (int i = 0; i < (int) probs.size(); i++) {
|
||||||
|
probs[i] *= cumsum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//printf("\n");
|
||||||
|
//for (int i = 0; i < (int) probs.size(); i++) {
|
||||||
|
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
|
||||||
|
//}
|
||||||
|
//exit(0);
|
||||||
|
|
||||||
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
|
int idx = dist(rng);
|
||||||
|
|
||||||
|
return logits_id[idx].second;
|
||||||
|
}
|
83
utils.h
Normal file
83
utils.h
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
// Various helper functions and utilities
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <random>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
//
|
||||||
|
// CLI argument parsing
|
||||||
|
//
|
||||||
|
|
||||||
|
struct gpt_params {
|
||||||
|
int32_t seed = -1; // RNG seed
|
||||||
|
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
|
int32_t n_predict = 200; // new tokens to predict
|
||||||
|
|
||||||
|
// sampling parameters
|
||||||
|
int32_t top_k = 40;
|
||||||
|
float top_p = 0.9f;
|
||||||
|
float temp = 0.9f;
|
||||||
|
|
||||||
|
int32_t n_batch = 8; // batch size for prompt processing
|
||||||
|
|
||||||
|
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
|
||||||
|
std::string prompt;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
|
||||||
|
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
||||||
|
|
||||||
|
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Vocab utils
|
||||||
|
//
|
||||||
|
|
||||||
|
struct gpt_vocab {
|
||||||
|
using id = int32_t;
|
||||||
|
using token = std::string;
|
||||||
|
|
||||||
|
std::map<token, id> token_to_id;
|
||||||
|
std::map<id, token> id_to_token;
|
||||||
|
};
|
||||||
|
|
||||||
|
void replace(std::string & str, const std::string & needle, const std::string & replacement);
|
||||||
|
|
||||||
|
// poor-man's JSON parsing
|
||||||
|
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||||
|
|
||||||
|
// split text into tokens
|
||||||
|
//
|
||||||
|
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||||
|
//
|
||||||
|
// Regex (Python):
|
||||||
|
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||||
|
//
|
||||||
|
// Regex (C++):
|
||||||
|
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||||
|
//
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||||
|
|
||||||
|
// load the tokens from encoder.json
|
||||||
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||||
|
|
||||||
|
// sample next token given probabilities for each embedding
|
||||||
|
//
|
||||||
|
// - consider only the top K tokens
|
||||||
|
// - from them, consider only the top tokens with cumulative probability > P
|
||||||
|
//
|
||||||
|
// TODO: not sure if this implementation is correct
|
||||||
|
// TODO: temperature is not implemented
|
||||||
|
//
|
||||||
|
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||||
|
const gpt_vocab & vocab,
|
||||||
|
const float * logits,
|
||||||
|
int top_k,
|
||||||
|
double top_p,
|
||||||
|
double temp,
|
||||||
|
std::mt19937 & rng);
|
Loading…
Reference in New Issue
Block a user