mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-25 23:13:06 +00:00
backend: dedupe tokenizing code in gptj/mpt
This commit is contained in:
parent
16b7bf01a8
commit
fc2869f0b7
@ -1,6 +1,8 @@
|
|||||||
#include "mpt.h"
|
#include "mpt.h"
|
||||||
#include "llama.cpp/ggml.h"
|
#include "llama.cpp/ggml.h"
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
@ -136,26 +138,8 @@ static bool kv_cache_init(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vocabulary tables consumed by the MPT tokenizer functions in this file.
struct mpt_vocab {
    using id    = int32_t;
    using token = std::string;

    // Lookup from token text to token id, and the reverse direction.
    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;

    // Token strings registered through add_special_token(), in call order.
    std::vector<std::string> special_tokens;

    // Append `token` to the list of special tokens.
    void add_special_token(const std::string &token) {
        special_tokens.emplace_back(token);
    }
};
|
|
||||||
|
|
||||||
// Escape `s` so that it matches itself literally when embedded in an
// ECMAScript std::regex pattern.
//
// Each of . ^ $ - + ( ) [ ] { } | ? * and the backslash itself is prefixed
// with a backslash; every other character is copied through unchanged.
// The backslash must be escaped too, otherwise a token containing one would
// be read by the regex engine as the start of an escape sequence.
std::string regex_escape(const std::string &s) {
    static const std::string metacharacters = R"(.^$-+()[]{}|?*\)";

    std::string escaped;
    escaped.reserve(s.size() * 2); // worst case: every character needs a prefix
    for (const char c : s) {
        if (metacharacters.find(c) != std::string::npos) {
            escaped += '\\';
        }
        escaped += c;
    }
    return escaped;
}
|
|
||||||
|
|
||||||
// load the model's weights from a stream
|
// load the model's weights from a stream
|
||||||
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
|
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
|
||||||
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||||
|
|
||||||
// verify magic
|
// verify magic
|
||||||
@ -219,8 +203,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
|||||||
vocab.id_to_token[i] = word;
|
vocab.id_to_token[i] = word;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this only kind-of works, the gpt_tokenize can still incorrectly
|
|
||||||
// tokenize special tokens
|
|
||||||
if(special) {
|
if(special) {
|
||||||
vocab.add_special_token(word);
|
vocab.add_special_token(word);
|
||||||
}
|
}
|
||||||
@ -436,7 +418,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load the model's weights from a file path
|
// load the model's weights from a file path
|
||||||
bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
|
bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
|
||||||
|
|
||||||
auto fin = std::ifstream(fname, std::ios::binary);
|
auto fin = std::ifstream(fname, std::ios::binary);
|
||||||
if (!fin) {
|
if (!fin) {
|
||||||
@ -647,98 +629,6 @@ bool mpt_eval(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) {
|
|
||||||
// taken from stablelm example in ggml
|
|
||||||
// they both use the gpt-neox tokenizer
|
|
||||||
// not sure if this entirely right?
|
|
||||||
std::vector<std::string> words;
|
|
||||||
|
|
||||||
|
|
||||||
// first split the text into words
|
|
||||||
{
|
|
||||||
std::string str = text;
|
|
||||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
|
||||||
std::regex re(pat);
|
|
||||||
std::smatch m;
|
|
||||||
|
|
||||||
while (std::regex_search(str, m, re)) {
|
|
||||||
for (auto x : m) {
|
|
||||||
words.push_back(x);
|
|
||||||
}
|
|
||||||
str = m.suffix();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// find the longest tokens that form the words:
|
|
||||||
std::vector<mpt_vocab::id> tokens;
|
|
||||||
for (const auto & word : words) {
|
|
||||||
if (word.size() == 0) continue;
|
|
||||||
|
|
||||||
int i = 0;
|
|
||||||
int n = word.size();
|
|
||||||
while (i < n) {
|
|
||||||
int j = n;
|
|
||||||
while (j > i) {
|
|
||||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
|
||||||
if (it != vocab.token_to_id.end()) {
|
|
||||||
tokens.push_back(it->second);
|
|
||||||
i = j;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
--j;
|
|
||||||
}
|
|
||||||
if (i == n) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (j == i) {
|
|
||||||
auto sub = word.substr(i, 1);
|
|
||||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
|
||||||
tokens.push_back(vocab.token_to_id.at(sub));
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
|
||||||
}
|
|
||||||
++i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text) {
|
|
||||||
// Generate the subpattern from the special_tokens vector if it's not empty
|
|
||||||
if (!vocab.special_tokens.empty()) {
|
|
||||||
std::vector<mpt_vocab::id> out;
|
|
||||||
std::vector<std::string> chunks;
|
|
||||||
std::string str = text;
|
|
||||||
std::string special_tokens_subpattern;
|
|
||||||
for (const auto &token : vocab.special_tokens) {
|
|
||||||
if (!special_tokens_subpattern.empty()) {
|
|
||||||
special_tokens_subpattern += "|";
|
|
||||||
}
|
|
||||||
special_tokens_subpattern += regex_escape(token);
|
|
||||||
}
|
|
||||||
std::regex re(special_tokens_subpattern);
|
|
||||||
std::smatch m;
|
|
||||||
while (std::regex_search(str, m, re)) {
|
|
||||||
auto tok = vocab.token_to_id.find(m.str());
|
|
||||||
if (tok != vocab.token_to_id.end()) {
|
|
||||||
auto tokid = tok->second;
|
|
||||||
auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix());
|
|
||||||
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
|
|
||||||
out.push_back(tokid);
|
|
||||||
str = m.suffix();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!str.empty()) {
|
|
||||||
auto tokrest = mpt_tokenize_inner(vocab, str);
|
|
||||||
out.insert(out.end(), tokrest.begin(), tokrest.end());
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
} else {
|
|
||||||
return mpt_tokenize_inner(vocab, text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define MPT_MAX_RNG_STATE 64*1024
|
#define MPT_MAX_RNG_STATE 64*1024
|
||||||
|
|
||||||
@ -801,8 +691,8 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
|
|||||||
return written;
|
return written;
|
||||||
}
|
}
|
||||||
|
|
||||||
mpt_vocab::id mpt_sample_top_k_top_p(
|
gpt_vocab::id mpt_sample_top_k_top_p(
|
||||||
const mpt_vocab & vocab,
|
const gpt_vocab & vocab,
|
||||||
const size_t actualVocabSize,
|
const size_t actualVocabSize,
|
||||||
const int32_t * last_n_tokens_data,
|
const int32_t * last_n_tokens_data,
|
||||||
int last_n_tokens_size,
|
int last_n_tokens_size,
|
||||||
@ -817,7 +707,7 @@ mpt_vocab::id mpt_sample_top_k_top_p(
|
|||||||
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
|
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
|
||||||
const auto * plogits = logits.data() + logits.size() - n_logits;
|
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||||
|
|
||||||
std::vector<std::pair<double, mpt_vocab::id>> logits_id;
|
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||||
logits_id.reserve(n_logits);
|
logits_id.reserve(n_logits);
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -842,7 +732,7 @@ mpt_vocab::id mpt_sample_top_k_top_p(
|
|||||||
std::partial_sort(
|
std::partial_sort(
|
||||||
logits_id.begin(),
|
logits_id.begin(),
|
||||||
logits_id.begin() + top_k, logits_id.end(),
|
logits_id.begin() + top_k, logits_id.end(),
|
||||||
[](const std::pair<double, mpt_vocab::id> & a, const std::pair<double, mpt_vocab::id> & b) {
|
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||||
return a.first > b.first;
|
return a.first > b.first;
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -952,7 +842,7 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
|
|||||||
struct MPTPrivate {
|
struct MPTPrivate {
|
||||||
const std::string modelPath;
|
const std::string modelPath;
|
||||||
bool modelLoaded;
|
bool modelLoaded;
|
||||||
mpt_vocab vocab;
|
gpt_vocab vocab;
|
||||||
mpt_model *model = nullptr;
|
mpt_model *model = nullptr;
|
||||||
int64_t n_threads = 0;
|
int64_t n_threads = 0;
|
||||||
size_t mem_per_token = 0;
|
size_t mem_per_token = 0;
|
||||||
@ -1037,7 +927,7 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
int64_t t_prompt_us = 0;
|
int64_t t_prompt_us = 0;
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<int> embd_inp = mpt_tokenize(d_ptr->vocab, prompt);
|
std::vector<int> embd_inp = gpt_tokenize(d_ptr->vocab, prompt);
|
||||||
|
|
||||||
// save the context size
|
// save the context size
|
||||||
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
||||||
|
@ -102,7 +102,7 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
|
|
||||||
// first split the text into words
|
// first split the text into words
|
||||||
@ -157,6 +157,47 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
|
|||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Escape `s` so that it matches itself literally when embedded in an
// ECMAScript std::regex pattern.
//
// Each of . ^ $ - + ( ) [ ] { } | ? * and the backslash itself is prefixed
// with a backslash; every other character is copied through unchanged.
// The backslash must be escaped too, otherwise a token containing one would
// be read by the regex engine as the start of an escape sequence.
std::string regex_escape(const std::string &s) {
    static const std::string metacharacters = R"(.^$-+()[]{}|?*\)";

    std::string escaped;
    escaped.reserve(s.size() * 2); // worst case: every character needs a prefix
    for (const char c : s) {
        if (metacharacters.find(c) != std::string::npos) {
            escaped += '\\';
        }
        escaped += c;
    }
    return escaped;
}
|
||||||
|
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||||
|
// Generate the subpattern from the special_tokens vector if it's not empty
|
||||||
|
if (!vocab.special_tokens.empty()) {
|
||||||
|
std::vector<gpt_vocab::id> out;
|
||||||
|
std::vector<std::string> chunks;
|
||||||
|
std::string str = text;
|
||||||
|
std::string special_tokens_subpattern;
|
||||||
|
for (const auto &token : vocab.special_tokens) {
|
||||||
|
if (!special_tokens_subpattern.empty()) {
|
||||||
|
special_tokens_subpattern += "|";
|
||||||
|
}
|
||||||
|
special_tokens_subpattern += regex_escape(token);
|
||||||
|
}
|
||||||
|
std::regex re(special_tokens_subpattern);
|
||||||
|
std::smatch m;
|
||||||
|
while (std::regex_search(str, m, re)) {
|
||||||
|
auto tok = vocab.token_to_id.find(m.str());
|
||||||
|
if (tok != vocab.token_to_id.end()) {
|
||||||
|
auto tokid = tok->second;
|
||||||
|
auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
|
||||||
|
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
|
||||||
|
out.push_back(tokid);
|
||||||
|
str = m.suffix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!str.empty()) {
|
||||||
|
auto tokrest = gpt_tokenize_inner(vocab, str);
|
||||||
|
out.insert(out.end(), tokrest.begin(), tokrest.end());
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
} else {
|
||||||
|
return gpt_tokenize_inner(vocab, text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
@ -44,6 +44,11 @@ struct gpt_vocab {
|
|||||||
|
|
||||||
std::map<token, id> token_to_id;
|
std::map<token, id> token_to_id;
|
||||||
std::map<id, token> id_to_token;
|
std::map<id, token> id_to_token;
|
||||||
|
std::vector<std::string> special_tokens;
|
||||||
|
|
||||||
|
void add_special_token(const std::string &token) {
|
||||||
|
special_tokens.push_back(token);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void replace(std::string & str, const std::string & needle, const std::string & replacement);
|
void replace(std::string & str, const std::string & needle, const std::string & replacement);
|
||||||
|
Loading…
Reference in New Issue
Block a user