Use the token cache to infer greater n_past and reuse results (#3073)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel
Date: 2024-10-31 11:19:12 -04:00 (committed by GitHub)
Parent: 62cab695eb
Commit: f07e2e63df

15 changed files with 320 additions and 169 deletions
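
What this changes, in brief: PromptContext no longer carries the token history across calls (tokens, n_past, and n_ctx are removed from the C++ struct, and tokens/tokens_size/n_past/n_ctx from llmodel_prompt_context); the model now owns an input-token cache. Before decoding, computeModelInputPosition() compares the freshly tokenized prompt against that cache, and decodePrompt() skips the common prefix, still firing the prompt/response callbacks for the skipped tokens but not re-evaluating them, so a conversation that resends its history reuses the existing KV cache. Saved state now carries the token cache alongside the backend state blob so the comparison still works after a restore. The following is a standalone sketch of the prefix-reuse idea only; the helper name and token values are illustrative, not code from the commit.

// Standalone sketch of the prefix-reuse idea: given the tokens already decoded
// into the KV cache and the tokens of the new prompt, only the suffix after the
// longest common prefix needs to be evaluated. Illustrative, not library code.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using Token = int32_t;

// Number of leading tokens shared by the cache and the new input.
static size_t commonPrefixLength(const std::vector<Token> &cache, const std::vector<Token> &input)
{
    auto mism = std::mismatch(cache.begin(), cache.end(), input.begin(), input.end());
    return size_t(mism.second - input.begin());
}

int main()
{
    std::vector<Token> cache {1, 15043, 29892, 3186};        // tokens already in the KV cache
    std::vector<Token> input {1, 15043, 29892, 3186, 29991}; // new prompt: same prefix plus one new token

    size_t reused = commonPrefixLength(cache, input);
    std::printf("reuse %zu cached tokens, decode %zu new ones\n", reused, input.size() - reused);
    return 0;
}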

View File

@@ -124,9 +124,7 @@ public:
};
struct PromptContext {
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -151,8 +149,8 @@ public:
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
@@ -210,6 +208,8 @@ public:
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
virtual int32_t contextLength() const = 0;
protected:
// These are pure virtual because subclasses need to implement them, as the default
// implementation of 'prompt' above calls these functions
@@ -218,9 +218,15 @@ protected:
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@@ -252,11 +258,13 @@ protected:
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false);
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
protected:
Token m_tokenize_last_token = -1; // not serialized
friend class LLMImplementation;
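
The state interface now round-trips the token cache as well: saveState() writes the backend state blob and copies out the model's input tokens, and restoreState() takes both back so a restored context knows which tokens its KV cache corresponds to. A minimal sketch of a round trip through these virtuals follows, assuming llmodel.h is included and `model` is a loaded subclass instance; SavedState and the helper functions are illustrative, not part of the library.

// Illustrative only: round-trips the new two-part state (KV blob + token cache)
// through the virtual interface declared above.
#include <cstdint>
#include <vector>

struct SavedState {
    std::vector<uint8_t>        kvState;     // opaque backend state blob
    std::vector<LLModel::Token> inputTokens; // token cache matching that state
};

static SavedState saveModelState(const LLModel &model)
{
    SavedState s;
    s.kvState.resize(model.stateSize());
    if (model.saveState(s.kvState, s.inputTokens) == 0)
        s = {}; // save failed; return an empty state
    return s;
}

static bool restoreModelState(LLModel &model, const SavedState &s)
{
    return model.restoreState(s.kvState, s.inputTokens) != 0;
}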

View File

@@ -23,6 +23,11 @@ extern "C" {
*/
typedef void *llmodel_model;
/**
* A token.
*/
typedef int32_t token_t;
/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
@@ -30,10 +35,7 @@ typedef void *llmodel_model;
* behavior.
*/
struct llmodel_prompt_context {
int32_t *tokens; // current tokens in the context window
size_t tokens_size; // the size of the raw tokens vector
int32_t n_past; // number of tokens in past conversation
int32_t n_ctx; // number of tokens possible in context window
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
@@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
uint64_t llmodel_state_get_size(llmodel_model model);
/**
* Saves the internal state of the model to the specified destination address.
* Saves the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
* @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
* @param state_size The size of the destination for the state.
* @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
* be freed with llmodel_state_free_input_tokens.
* @param n_input_tokens Where to store the size of the token cache state.
* @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
* size is set to zero.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens);
/**
* Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
* @param input_tokens The token cache buffer.
*/
void llmodel_state_free_input_tokens(token_t *input_tokens);
/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the state data.
* @param size The size of the source data.
* @param state A pointer to the state data.
* @param state_size The size of the state data.
* @param input_tokens The token cache associated with the saved state.
* @param n_input_tokens The number of tokens in input_tokens.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens);
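
For reference, a minimal sketch of how a caller might exercise the renamed state functions declared above, assuming this header is included and `model` is a handle to a loaded model; error handling is deliberately sparse.

// Sketch only: save and restore model state through the C API, usable from C++.
#include <cstdint>
#include <cstdio>
#include <vector>

void stateRoundTrip(llmodel_model model)
{
    std::vector<uint8_t> state(llmodel_state_get_size(model));

    token_t *input_tokens = nullptr;
    uint64_t n_input_tokens = 0;
    uint64_t written = llmodel_state_get_data(model, state.data(), state.size(),
                                              &input_tokens, &n_input_tokens);
    if (written != 0) {
        // ...later, put the model back exactly where it was, token cache included...
        if (llmodel_state_set_data(model, state.data(), written,
                                   input_tokens, n_input_tokens) == 0)
            std::fprintf(stderr, "state restore failed\n");
    }
    llmodel_state_free_input_tokens(input_tokens); // no-op when input_tokens is null
}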
/**
* Generate a response using the model.

View File

@@ -218,6 +218,7 @@ struct LLamaPrivate {
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
std::vector<LLModel::Token> inputTokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
@@ -501,14 +502,20 @@ size_t LLamaModel::stateSize() const
return llama_state_get_size(d_ptr->ctx);
}
size_t LLamaModel::saveState(std::span<uint8_t> dest) const
size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
{
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
if (bytesWritten)
inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
return bytesWritten;
}
size_t LLamaModel::restoreState(std::span<const uint8_t> src)
size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
{
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
if (bytesRead)
d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
return bytesRead;
}
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
@@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
@@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));
assert(n_discard > 0);
if (n_discard <= 0)
@@ -638,8 +645,9 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
}
int32_t LLamaModel::contextLength() const
@@ -647,6 +655,60 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx);
}
int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}
void LLamaModel::setTokenizeInputPosition(int32_t pos)
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}
auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}
// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
}
void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
assert(pos <= inp.size());
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}
void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}
auto LLamaModel::inputTokens() const -> std::span<const Token>
{
return d_ptr->inputTokens;
}
const std::vector<LLModel::Token> &LLamaModel::endTokens() const
{
return d_ptr->end_tokens;

View File

@@ -28,8 +28,8 @@ public:
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@@ -48,10 +48,7 @@ public:
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
int32_t contextLength() const override;
protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
@@ -59,9 +56,15 @@ protected:
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
@@ -70,6 +73,11 @@ protected:
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
const EmbModelSpec *spec);
private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
};
#endif // LLAMAMODEL_H

View File

@@ -14,6 +14,11 @@
#include <string>
#include <string_view>
#include <vector>
#include <span>
namespace ranges = std::ranges;
static_assert(sizeof(token_t) == sizeof(LLModel::Token));
struct LLModelWrapper {
LLModel *llModel = nullptr;
@@ -85,22 +90,40 @@ bool llmodel_isModelLoaded(llmodel_model model)
return wrapper->llModel->isModelLoaded();
}
uint64_t llmodel_get_state_size(llmodel_model model)
uint64_t llmodel_state_get_size(llmodel_model model)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState({dest, size_t(size)});
std::vector<LLModel::Token> inputTokens;
auto bytesWritten = wrapper->llModel->saveState({state_out, size_t(state_size)}, inputTokens);
if (bytesWritten) {
auto *buf = new LLModel::Token[inputTokens.size()];
ranges::copy(inputTokens, buf);
*input_tokens_out = buf;
*n_input_tokens = uint64_t(inputTokens.size());
} else {
*input_tokens_out = nullptr;
*n_input_tokens = 0;
}
return bytesWritten;
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
void llmodel_state_free_input_tokens(LLModel::Token *input_tokens)
{
delete[] input_tokens;
}
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState({src, size_t(size)});
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}
void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -120,7 +143,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
@@ -136,14 +158,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;

View File

@@ -6,6 +6,7 @@
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <optional>
#include <regex>
#include <sstream>
@@ -66,19 +67,14 @@ void LLModel::prompt(const std::string &prompt,
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > promptCtx.tokens.size()) {
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}
promptCtx.n_ctx = contextLength();
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
if (promptCtx.n_past < promptCtx.tokens.size())
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
// parse the prompt template
std::vector<std::smatch> placeholders;
{
@@ -90,6 +86,8 @@ void LLModel::prompt(const std::string &prompt,
}
}
setTokenizeInputPosition(promptCtx.n_past);
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
@@ -118,7 +116,8 @@ void LLModel::prompt(const std::string &prompt,
}
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error
// decode the assistant's reply, either generated or spoofed
@@ -151,36 +150,67 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse) {
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
bool isResponse,
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << promptCtx.n_ctx << "!\n";
" tokens and the context window is " << contextLength() << "!\n";
return false;
}
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << promptCtx.n_ctx << "\n";
<< ", n_ctx=" << contextLength() << "\n";
return false;
}
// process the prompt in batches
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}
// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();
// always decode up to a full batch before generating, even if cached
if (alwaysDecode)
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
// execute the callback even for skipped tokens
size_t i = 0;
for (; i < start_offset; i++) {
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
}
// process the prompt in batches
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
}
if (!evalTokens(promptCtx, batch)) {
@@ -190,9 +220,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
promptCtx.tokens.push_back(batch.at(t));
promptCtx.n_past += 1;
Token tok = batch.at(t);
Token tok = batch[t];
appendInputToken(promptCtx, tok);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
@@ -232,8 +261,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
@@ -254,23 +283,22 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
// Shift context if out of space
if (promptCtx.n_past >= promptCtx.n_ctx) {
if (promptCtx.n_past >= contextLength()) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < promptCtx.n_ctx);
assert(promptCtx.n_past < contextLength());
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { tok })) {
if (!evalTokens(promptCtx, { &tok, 1 })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
promptCtx.tokens.push_back(tok);
promptCtx.n_past += 1;
appendInputToken(promptCtx, tok);
return true;
};
@@ -309,9 +337,9 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< promptCtx.n_ctx << "\n";
<< contextLength() << "\n";
stop = true;
}
@@ -357,16 +385,17 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
}
auto &tokens = promptCtx.tokens;
if (tokens.size() < cachedTokens.size()) {
if (inputLength() < cachedTokens.size()) {
/* This is theoretically possible if the longest stop sequence is greater than
* n_ctx * contextErase tokens. */
throw std::runtime_error("shifted too much context, can't go back");
}
auto discard_start = tokens.end() - cachedTokens.size();
assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
tokens.erase(discard_start, tokens.end());
#ifndef NDEBUG
auto inp = inputTokens();
auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif
promptCtx.n_past -= cachedTokens.size();
}
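
To make the new decodePrompt bookkeeping concrete, here is a small standalone sketch of the start_offset arithmetic with plain integers (names and values are illustrative, not library code): computeModelInputPosition() yields the length of the matched prefix, and under alwaysDecode that prefix is shortened by up to one batch so evalTokens() always runs before sampling, even when the entire prompt is already cached.

// Standalone illustration of the start_offset / alwaysDecode arithmetic used in
// decodePrompt above; not library code.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    int32_t n_batch       = 9;
    int32_t prompt_len    = 32; // tokens in embd_inp
    int32_t cached_prefix = 32; // tokens of embd_inp matched against the cache

    // computeModelInputPosition() found that the entire prompt is cached.
    int32_t start_offset = cached_prefix;

    // With alwaysDecode, back off by up to one batch so evalTokens() runs at
    // least once and there are logits to sample the first response token from.
    bool alwaysDecode = true;
    if (alwaysDecode)
        start_offset -= std::min(n_batch, start_offset);

    std::printf("skip %d cached tokens, re-decode %d, callbacks still fire for all %d\n",
                start_offset, prompt_len - start_offset, prompt_len);
    return 0;
}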