fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-02-21 15:45:32 -05:00
committed by GitHub
parent b8f5c74f40
commit 4fc4d94be4
22 changed files with 429 additions and 307 deletions

View File

@@ -2,11 +2,20 @@
#include <cassert>
#include <iostream>
#include <regex>
#include <unordered_set>
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
size_t i = 0;
promptCtx.n_past = 0;
int n_keep = shouldAddBOS();
const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;
// Erase the first percentage of context from the tokens
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
size_t i = n_keep;
promptCtx.n_past = n_keep;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
@@ -26,11 +35,36 @@ stop_generating:
recalculate(false);
}
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
placeholders.clear();
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
if (placeholders.size() > 2) {
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
}
void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx)
PromptContext &promptCtx,
bool special,
std::string *fakeReply)
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -38,15 +72,86 @@ void LLModel::prompt(const std::string &prompt,
}
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
responseCallback(-1, errorMessage);
std::cerr << implementation().modelType() << errorMessage;
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
return;
}
// tokenize the prompt
std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptCtx, promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty()) {
embd_inp = tokenize(promptCtx, userPrefix, true);
promptCtx.n_past += embd_inp.size();
}
// user input (shouldn't have special token processing)
auto tokens = tokenize(promptCtx, prompt, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(promptCtx, userToAsst, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
}
}
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
// decode the user prompt
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
generateResponse(responseCallback, recalculateCallback, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
}
// decode the rest of the prompt template
if (placeholders.size() >= 2) {
// template: end of assistant prompt
size_t start = placeholders[1].position() + placeholders[1].length();
auto asstSuffix = promptTemplate.substr(start);
if (!asstSuffix.empty()) {
embd_inp = tokenize(promptCtx, asstSuffix, true);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
}
}
}
void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx,
std::vector<Token> embd_inp) {
// save the context size
promptCtx.n_ctx = contextLength();
@@ -69,11 +174,6 @@ void LLModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
@@ -94,7 +194,11 @@ void LLModel::prompt(const std::string &prompt,
}
i = batch_end;
}
}
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
std::string cachedResponse;
std::vector<Token> cachedTokens;
std::unordered_set<std::string> reversePrompts
@@ -108,11 +212,6 @@ void LLModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
@@ -165,8 +264,9 @@ void LLModel::prompt(const std::string &prompt,
}
}
std::vector<float> LLModel::embedding(const std::string &/*text*/)
std::vector<float> LLModel::embedding(const std::string &text)
{
(void)text;
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
std::cerr << implementation().modelType() << errorMessage;