mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-07 03:20:26 +00:00
New tokenizer implementation for MPT and GPT-J
Improves output quality by making these tokenizers more closely match the behavior of the huggingface `tokenizers` based BPE tokenizers these models were trained with. Featuring: * Fixed unicode handling (via ICU) * Fixed BPE token merge handling * Complete added vocabulary handling
This commit is contained in:
@@ -1,220 +1,49 @@
|
||||
#include "utils.h"
|
||||
#include "tokenizer/bpe.h"
|
||||
#include "tokenizer/mpt_tokenizer_config.h"
|
||||
#include "tokenizer/gptj_tokenizer_config.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
|
||||
void replace(std::string & str, const std::string & needle, const std::string & replacement) {
|
||||
size_t pos = 0;
|
||||
while ((pos = str.find(needle, pos)) != std::string::npos) {
|
||||
str.replace(pos, needle.length(), replacement);
|
||||
pos += replacement.length();
|
||||
}
|
||||
}
|
||||
void get_bpecpp_tokenizer(const TokenizerType ttype, std::unique_ptr<bpecpp::BPE>& bpe, std::unique_ptr<bpecpp::AdditionalVocabAdapter>& av) {
|
||||
std::vector<bpecpp::additional_vocab_item> avis;
|
||||
std::unordered_map<std::string_view, uint32_t> vocab;
|
||||
std::vector<std::pair<std::string_view, std::string_view>> merges;
|
||||
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||
std::map<std::string, int32_t> result;
|
||||
|
||||
// read file into string
|
||||
std::string json;
|
||||
{
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs) {
|
||||
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
}
|
||||
|
||||
if (json[0] != '{') {
|
||||
return result;
|
||||
}
|
||||
|
||||
// parse json
|
||||
{
|
||||
bool has_key = false;
|
||||
bool in_token = false;
|
||||
|
||||
std::string str_key = "";
|
||||
std::string str_val = "";
|
||||
|
||||
int n = json.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
if (!in_token) {
|
||||
if (json[i] == ' ') continue;
|
||||
if (json[i] == '"') {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (json[i] == '\\' && i+1 < n) {
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
++i;
|
||||
} else if (json[i] == '"') {
|
||||
if (has_key == false) {
|
||||
has_key = true;
|
||||
++i;
|
||||
while (json[i] == ' ') ++i;
|
||||
++i; // :
|
||||
while (json[i] == ' ') ++i;
|
||||
if (json[i] != '\"') {
|
||||
while (json[i] != ',' && json[i] != '}') {
|
||||
str_val += json[i++];
|
||||
}
|
||||
has_key = false;
|
||||
} else {
|
||||
in_token = true;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
has_key = false;
|
||||
}
|
||||
|
||||
::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||
::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||
::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||
|
||||
try {
|
||||
result[str_key] = std::stoi(str_val);
|
||||
} catch (...) {
|
||||
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||
|
||||
}
|
||||
str_key = "";
|
||||
str_val = "";
|
||||
in_token = false;
|
||||
continue;
|
||||
}
|
||||
if (has_key == false) {
|
||||
str_key += json[i];
|
||||
} else {
|
||||
str_val += json[i];
|
||||
}
|
||||
uint32_t tok_id = 0;
|
||||
switch (ttype) {
|
||||
case TokenizerType::MPT_CHAT:
|
||||
avis.push_back({ .id = 50277, .content = std::string_view("<|im_start|>"), .special = true });
|
||||
avis.push_back({ .id = 50278, .content = std::string_view("<|im_end|>"), .special = true });
|
||||
case TokenizerType::MPT:
|
||||
for (auto avi_e: mpt_additional_vocab) {
|
||||
avis.push_back({avi_e.id, avi_e.content.into(mpt_buffer), avi_e.special});
|
||||
}
|
||||
}
|
||||
for (auto merge: mpt_merges) {
|
||||
merges.push_back({merge.first.into(mpt_buffer), merge.second.into(mpt_buffer)});
|
||||
}
|
||||
for (auto bufref: mpt_vocab) {
|
||||
vocab.insert({bufref.into(mpt_buffer), tok_id++});
|
||||
}
|
||||
break;
|
||||
case TokenizerType::GPTJ:
|
||||
for (auto avi_e: gptj_additional_vocab) {
|
||||
avis.push_back({avi_e.id, avi_e.content.into(gptj_buffer), avi_e.special});
|
||||
}
|
||||
for (auto merge: gptj_merges) {
|
||||
merges.push_back({merge.first.into(gptj_buffer), merge.second.into(gptj_buffer)});
|
||||
}
|
||||
for (auto bufref: gptj_vocab) {
|
||||
vocab.insert({bufref.into(gptj_buffer), tok_id++});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("invalid tokenizer type");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string &s) {
|
||||
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
|
||||
return std::regex_replace(s, metacharacters, "\\$&");
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
// Generate the subpattern from the special_tokens vector if it's not empty
|
||||
if (!vocab.special_tokens.empty()) {
|
||||
std::vector<gpt_vocab::id> out;
|
||||
std::vector<std::string> chunks;
|
||||
std::string str = text;
|
||||
std::string special_tokens_subpattern;
|
||||
for (const auto &token : vocab.special_tokens) {
|
||||
if (!special_tokens_subpattern.empty()) {
|
||||
special_tokens_subpattern += "|";
|
||||
}
|
||||
special_tokens_subpattern += regex_escape(token);
|
||||
}
|
||||
std::regex re(special_tokens_subpattern);
|
||||
std::smatch m;
|
||||
while (std::regex_search(str, m, re)) {
|
||||
auto tok = vocab.token_to_id.find(m.str());
|
||||
if (tok != vocab.token_to_id.end()) {
|
||||
auto tokid = tok->second;
|
||||
auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
|
||||
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
|
||||
out.push_back(tokid);
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
if (!str.empty()) {
|
||||
auto tokrest = gpt_tokenize_inner(vocab, str);
|
||||
out.insert(out.end(), tokrest.begin(), tokrest.end());
|
||||
}
|
||||
return out;
|
||||
} else {
|
||||
return gpt_tokenize_inner(vocab, text);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
vocab.token_to_id = ::json_parse(fname);
|
||||
|
||||
for (const auto & kv : vocab.token_to_id) {
|
||||
vocab.id_to_token[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||
|
||||
// print the vocabulary
|
||||
//for (auto kv : vocab.token_to_id) {
|
||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||
//}
|
||||
|
||||
return true;
|
||||
av = std::make_unique<bpecpp::AdditionalVocabAdapter>(avis);
|
||||
bpe = std::make_unique<bpecpp::BPE>(vocab, merges);
|
||||
}
|
||||
|
||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
@@ -313,4 +142,4 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
int idx = dist(rng);
|
||||
|
||||
return logits_id[idx].second;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user