Revert "New tokenizer implementation for MPT and GPT-J"

This reverts commit ee3469ba6c.
This commit is contained in:
Adam Treat
2023-05-30 12:59:00 -04:00
parent 06434f0042
commit 4a317eeb33
13 changed files with 241 additions and 47164 deletions

View File

@@ -7,7 +7,6 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <map>
#include <random>
@@ -786,12 +785,6 @@ bool MPT::loadModel(const std::string &modelPath) {
d_ptr->modelLoaded = true;
d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
fflush(stdout);
if (modelPath.find("-chat") != std::string::npos) {
get_bpecpp_tokenizer(TokenizerType::MPT_CHAT, m_bpe, m_tokav);
} else {
get_bpecpp_tokenizer(TokenizerType::MPT, m_bpe, m_tokav);
}
return true;
}
@@ -847,7 +840,7 @@ void MPT::prompt(const std::string &prompt,
int64_t t_prompt_us = 0;
// tokenize the prompt
std::vector<uint32_t> embd_inp = m_tokav->encode(prompt, *m_bpe);
std::vector<int> embd_inp = gpt_tokenize(d_ptr->vocab, prompt);
// save the context size
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
@@ -913,7 +906,6 @@ void MPT::prompt(const std::string &prompt,
int r_instructFound = 0;
std::string cachedResponse;
std::string decodeBuffer;
std::vector<int> cachedTokens;
std::unordered_set<std::string> reversePrompts
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
@@ -969,7 +961,7 @@ void MPT::prompt(const std::string &prompt,
if (id == 0 /*end of text*/)
goto stop_generating;
const std::string str = m_tokav->decode({(uint32_t) id}, *m_bpe, true, false);
const std::string str = d_ptr->vocab.id_to_token[id];
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
@@ -999,8 +991,7 @@ void MPT::prompt(const std::string &prompt,
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
const std::string decoded = m_tokav->decode({(uint32_t) t}, *m_bpe, true, false);
if (!responseCallback(t, decoded))
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
goto stop_generating;
}
cachedTokens.clear();