chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
Author: Jared Van Bortel
Date: 2024-08-07 11:25:24 -04:00
Committed by: GitHub
Parent: 90de2d32f8
Commit: be66ec8ab5
16 changed files with 285 additions and 230 deletions
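
As a rough sketch of how a caller adapts to the reworked C++ interface (the lambdas, template string, and helper name below are illustrative assumptions, not code from this commit), the old recalculate callback simply becomes a boolean flag:

#include "llmodel.h"
#include <iostream>
#include <string>

// Hypothetical helper; only the prompt() parameter order reflects this commit.
void runPrompt(LLModel *model, LLModel::PromptContext &promptCtx, const std::string &userInput)
{
    model->prompt(userInput, "### Human:\n%1\n### Assistant:\n",
                  [](int32_t /*tokenId*/) { return true; },               // prompt progress
                  [](int32_t /*tokenId*/, const std::string &piece) {     // streamed response
                      std::cout << piece << std::flush;
                      return true;
                  },
                  /*allowContextShift*/ true,  // replaces the old recalculateCallback argument
                  promptCtx);
}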

View File

@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)

View File

@@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || (
llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
);
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
@@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres;
}
bool LLamaModel::isSpecialToken(Token id) const
{
return llama_token_get_attr(d_ptr->model, id)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}
std::string LLamaModel::tokenToString(Token id) const
{
std::vector<char> result(8, 0);
@@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
return res == 0;
}
void LLamaModel::shiftContext(PromptContext &promptCtx)
{
// infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
assert(n_discard > 0);
if (n_discard <= 0)
return;
std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
<< ", n_discard = " << n_discard << "\n";
// erase the first n_discard tokens from the context
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
}
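// Worked example with assumed numbers (not from this commit): n_ctx = 2048,
// contextErase = 0.5, BOS kept so n_keep = 1, and a full cache (n_past = 2048).
//   n_discard = std::min(2048 - 1, int(2048 * 0.5f)) = 1024
//   llama_kv_cache_seq_rm (ctx, 0, 1, 1025);           // drop cached positions [1, 1025)
//   llama_kv_cache_seq_add(ctx, 0, 1025, 2048, -1024); // slide [1025, 2048) back by 1024
// afterwards n_past = 1024, leaving half the window free so generation can continue.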
int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);

View File

@@ -6,7 +6,6 @@
#include "llmodel.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>
@@ -54,9 +53,11 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;

View File

@@ -134,7 +134,7 @@ public:
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
float contextErase = 0.75f; // percent of context to erase if we exceed the context window
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
using ProgressCallback = std::function<bool(float progress)>;
@@ -159,7 +159,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
@@ -213,9 +213,11 @@ protected:
// These are pure virtual because subclasses need to implement them, as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@@ -232,10 +234,6 @@ protected:
return -1;
}
// This is a helper function called from the default implementation of 'prompt' but it can be
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
@@ -249,11 +247,11 @@ protected:
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx);
Token m_tokenize_last_token = -1; // not serialized

View File

@@ -106,7 +106,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply)
@@ -135,7 +135,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special, fake_reply_p);
// Update the C context by giving access to the wrapper's raw pointers to std::vector data

View File

@@ -74,13 +74,6 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
/**
* Callback type for recalculation of context.
* @param whether the model is recalculating the context.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);
/**
* Embedding cancellation callback for use with llmodel_embed.
* @param batch_sizes The number of tokens in each batch that will be embedded.
@@ -175,7 +168,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
@@ -184,7 +177,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);
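/* Hedged usage sketch (not part of this commit): model, prompt_ctx, on_prompt_token, and
 * on_response_token are caller-defined placeholders; only the parameter order comes from the
 * declaration above. */
llmodel_prompt(model, "What is a KV cache?", "### Human:\n%1\n### Assistant:\n",
               on_prompt_token, on_response_token,
               /*allow_context_shift*/ true,
               &prompt_ctx, /*special*/ false, /*fake_reply*/ NULL);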

View File

@@ -11,42 +11,9 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
// FIXME(jared): if recalculate returns false, we leave n_past<tokens.size() and do not tell the caller to stop
// FIXME(jared): if we get here during chat name or follow-up generation, bad things will happen when we try to restore
// the old prompt context afterwards
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
int n_keep = shouldAddBOS();
const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;
// Erase the first percentage of context from the tokens
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
size_t i = n_keep;
promptCtx.n_past = n_keep;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
if (!evalTokens(promptCtx, batch)) {
std::cerr << "LLModel ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
stop_generating:
recalculate(false);
}
namespace ranges = std::ranges;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
@@ -75,7 +42,7 @@ void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply)
@@ -92,12 +59,21 @@ void LLModel::prompt(const std::string &prompt,
return;
}
// make sure token cache matches decode offset
if (promptCtx.tokens.size() < promptCtx.n_past) {
// sanity checks
if (promptCtx.n_past > contextLength()) {
std::ostringstream ss;
ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past;
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > promptCtx.tokens.size()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
throw std::out_of_range(ss.str());
}
promptCtx.n_ctx = contextLength();
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
if (promptCtx.n_past < promptCtx.tokens.size())
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
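// Illustrative numbers (not from this commit): resuming with n_past = 100 when the model's
// context length or the cached token list holds only 80 entries now throws std::out_of_range,
// while n_past = 100 with a 120-entry token cache trims the cache back to 100 entries so the
// decode offset and the tokenizer's last-token state stay in sync.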
@@ -149,15 +125,15 @@ void LLModel::prompt(const std::string &prompt,
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
return; // error
// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
generateResponse(responseCallback, recalculateCallback, promptCtx);
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
return; // error
}
@@ -172,19 +148,16 @@ void LLModel::prompt(const std::string &prompt,
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(promptCtx, asstSuffix, true);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
}
}
// returns false on error
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp) {
// save the context size
promptCtx.n_ctx = contextLength();
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@@ -192,9 +165,14 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
return false;
}
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << promptCtx.n_ctx << "\n";
return false;
}
// process the prompt in batches
size_t i = 0;
@@ -204,7 +182,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
recalculateContext(promptCtx, recalculateCallback);
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
@@ -226,70 +205,170 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
return true;
}
/*
* If string s overlaps with the string key such that some prefix of the key is at the end
* of the string, return the position in s where the first match starts. Otherwise, return
* std::string::npos. Examples:
* s = "bfo", key = "foo" -> 1
* s = "fooa", key = "foo" -> npos
*/
static std::string::size_type stringsOverlap(const std::string &s, const std::string &key)
{
if (s.empty() || key.empty())
throw std::invalid_argument("arguments to stringsOverlap must not be empty");
for (int start = std::max(0, int(s.size()) - int(key.size())); start < s.size(); start++) {
if (s.compare(start, s.size(), key, 0, s.size() - start) == 0)
return start;
}
return std::string::npos;
}
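// Illustration with assumed strings (not from this commit): if cachedResponse ends in "### Hu"
// and one stop sequence is "### Human", stringsOverlap() returns the index where "### Hu"
// starts, so that tail stays cached instead of being streamed; if later tokens complete the
// stop sequence the cached text is dropped, otherwise it is eventually flushed to the
// response callback.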
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx) {
static const char *stopSequences[] {
"### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
};
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
<< "\n";
return;
}
std::string cachedResponse;
std::vector<Token> cachedTokens;
std::unordered_set<std::string> reversePrompts
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
int n_predicted = 0;
// predict next tokens
for (int i = 0; i < promptCtx.n_predict; i++) {
// Predict next tokens
for (bool stop = false; !stop;) {
// Sample next token
std::optional<Token> new_tok = sampleToken(promptCtx);
std::string new_piece = tokenToString(new_tok.value());
cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece;
// sample next token
auto id = sampleToken(promptCtx);
auto accept = [this, &promptCtx, &cachedTokens, &new_tok, allowContextShift]() -> bool {
// Shift context if out of space
if (promptCtx.n_past >= promptCtx.n_ctx) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < promptCtx.n_ctx);
}
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { tok })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
if (!evalTokens(promptCtx, { id })) {
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return;
}
promptCtx.tokens.push_back(tok);
promptCtx.n_past += 1;
return true;
};
// display text
// Check for EOS
auto lengthLimit = std::string::npos;
for (const auto token : endTokens()) {
if (id == token) return;
}
const std::string str = tokenToString(id);
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + std::string(str);
if (reversePrompts.find(completed) != reversePrompts.end())
return;
// Check if it partially matches our reverse prompts and if so, cache
for (const auto& s : reversePrompts) {
if (s.compare(0, completed.size(), completed) == 0) {
foundPartialReversePrompt = true;
cachedResponse = completed;
break;
if (new_tok == token) {
stop = true;
lengthLimit = cachedResponse.size() - new_piece.size();
}
}
// Regardless the token gets added to our cache
cachedTokens.push_back(id);
if (lengthLimit != std::string::npos) {
// EOS matched
} else if (!isSpecialToken(new_tok.value())) {
// Check if the response contains a stop sequence
for (const auto &p : stopSequences) {
auto match = cachedResponse.find(p);
if (match != std::string::npos) stop = true;
lengthLimit = std::min(lengthLimit, match);
if (match == 0) break;
}
// Continue if we have found a partial match
if (foundPartialReversePrompt)
continue;
// Empty the cache
for (auto t : cachedTokens) {
promptCtx.tokens.push_back(t);
promptCtx.n_past += 1;
//TODO: Conversion to std::string can be avoided here...
if (!responseCallback(t, std::string(tokenToString(t))))
return;
// Check if the response matches the start of a stop sequence
if (lengthLimit == std::string::npos) {
for (const auto &p : stopSequences) {
auto match = stringsOverlap(cachedResponse, p);
lengthLimit = std::min(lengthLimit, match);
if (match == 0) break;
}
}
} else if (ranges::contains(stopSequences, new_piece)) {
// Special tokens must exactly match a stop sequence
stop = true;
lengthLimit = cachedResponse.size() - new_piece.size();
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< promptCtx.n_ctx << "\n";
stop = true;
}
// Empty the cache, up to the length limit
std::string::size_type responseLength = 0;
while (!cachedTokens.empty()) {
Token tok = cachedTokens.front();
std::string piece = tokenToString(tok);
// Stop if the piece (or part of it) does not fit within the length limit
if (responseLength + (stop ? 1 : piece.size()) > lengthLimit)
break;
// Remove token from cache
assert(cachedResponse.starts_with(piece));
cachedTokens.erase(cachedTokens.begin(), cachedTokens.begin() + 1);
cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());
// Accept the token, if needed (not cached)
if (cachedTokens.empty() && new_tok && !accept())
return;
// Send the token
if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
stop = true;
break;
}
// FIXME(jared): we could avoid printing partial stop sequences if we didn't have to
// output token IDs and could cache a partial token for the next prompt call
responseLength += piece.size();
}
assert(cachedTokens.empty() == cachedResponse.empty());
// Accept the token, if needed (in cache)
if (new_tok) {
assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
if (stop) {
cachedTokens.pop_back();
} else if (!accept()) {
return;
}
}
cachedTokens.clear();
}
auto &tokens = promptCtx.tokens;
if (tokens.size() < cachedTokens.size()) {
/* This is theoretically possible if the longest stop sequence is greater than
* n_ctx * contextErase tokens. */
throw std::runtime_error("shifted too much context, can't go back");
}
auto discard_start = tokens.end() - cachedTokens.size();
assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
tokens.erase(discard_start, tokens.end());
promptCtx.n_past -= cachedTokens.size();
}
void LLModel::embed(