Remove binary state from high-level API and use Jinja templates (#3147)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
Author: Jared Van Bortel <jared@nomic.ai>
Date: 2024-11-25 10:04:17 -05:00
Committed by: GitHub
Parent: 3320094d29
Commit: 225bf6be93
54 changed files with 3423 additions and 2224 deletions

View File

@@ -5,6 +5,7 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <expected>
#include <functional>
#include <optional>
#include <span>
@@ -24,6 +25,10 @@ using namespace std::string_literals;
class LLModel {
public:
using Token = int32_t;
using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
using ProgressCallback = std::function<bool(float progress)>;
class BadArchError: public std::runtime_error {
public:
@@ -101,6 +106,7 @@ public:
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static bool isEmbeddingModel(const std::string &modelPath);
static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();
static bool hasSupportedCPU();
@@ -124,7 +130,6 @@ public:
};
struct PromptContext {
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -136,8 +141,6 @@ public:
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {}
virtual ~LLModel() {}
@@ -154,16 +157,12 @@ public:
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::optional<std::string_view> fakeReply = {});
virtual void prompt(std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &ctx);
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
virtual int32_t countPromptTokens(std::string_view prompt) const;
virtual size_t embeddingSize() const {
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@@ -209,23 +208,22 @@ public:
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
virtual int32_t contextLength() const = 0;
virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
virtual std::vector<Token> tokenize(std::string_view str) const = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual void initSampler(const PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
virtual void setModelInputPosition(int32_t pos) = 0;
virtual void appendInputToken(Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@@ -242,6 +240,12 @@ protected:
return -1;
}
virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
(void)modelPath;
return std::unexpected("not implemented");
}
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
@@ -253,19 +257,15 @@ protected:
return true;
}
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
protected:
Token m_tokenize_last_token = -1; // not serialized
// prefill context with prompt
auto decodePrompt(const PromptCallback &promptCallback,
const PromptContext &promptCtx,
std::vector<Token> embd_inp)
-> std::optional<int32_t>;
// generate a response
void generateResponse(const ResponseCallback &responseCallback,
const PromptContext &promptCtx,
int32_t nPast);
friend class LLMImplementation;
};
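
For reference, a minimal sketch of how a caller might drive the reworked C++ API above. This is not part of the diff: the include path, the runPrompt helper, and the renderedPrompt argument are illustrative, and the prompt is assumed to be fully templated by the caller (e.g. with Jinja) now that the prompt-template and n_past plumbing is gone from this interface.

#include <exception>
#include <iostream>
#include <span>
#include <string_view>
#include "llmodel.h"   // include path is illustrative

// Hypothetical helper; assumes `model` is already loaded and supports completion.
void runPrompt(LLModel &model, std::string_view renderedPrompt)
{
    LLModel::PromptContext ctx;   // defaults: n_predict=200, top_k=40, top_p=0.9, ...
    ctx.n_predict = 64;

    auto onPrompt = [](std::span<const LLModel::Token> batch, bool cached) {
        (void)batch; (void)cached;        // tokens just processed, or skipped if cached
        return true;                      // return false to abort prompt processing
    };
    auto onResponse = [](LLModel::Token /*token*/, std::string_view piece) {
        std::cout << piece << std::flush; // stream each decoded piece
        return true;                      // return false to stop generation
    };

    try {
        model.prompt(renderedPrompt, onPrompt, onResponse, ctx);
    } catch (const std::exception &e) {
        std::cerr << "prompt failed: " << e.what() << '\n';
    }
}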

View File

@@ -35,16 +35,15 @@ typedef int32_t token_t;
* behavior.
*/
struct llmodel_prompt_context {
int32_t n_past; // number of tokens in past conversation
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution
float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution
int32_t n_batch; // number of predictions to generate in parallel
float repeat_penalty; // penalty factor for repeated tokens
float repeat_penalty; // penalty factor for repeated tokens
int32_t repeat_last_n; // last n tokens to penalize
float context_erase; // percent of context to erase if we exceed the context window
float context_erase; // percent of context to erase if we exceed the context window
};
struct llmodel_gpu_device {
@@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;
/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
* @param token_ids An array of token ids of the prompt.
* @param n_token_ids The number of tokens in the array.
* @param cached Whether the tokens were already in cache.
* @return a bool indicating whether the model should keep processing.
*/
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);
/**
* Callback type for response.
@@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);
/**
* Embedding cancellation callback for use with llmodel_embed.
@@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
*/
typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
typedef void (*llmodel_special_token_callback)(const char *name, const char *token);
/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
@@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
* @param error A pointer to a string; will only be set on error.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);
bool llmodel_prompt(llmodel_model model,
const char *prompt,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_prompt_context *ctx,
const char **error);
/**
* Generate an embedding using the model.
@@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
*/
const char *llmodel_model_gpu_device_name(llmodel_model model);
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
#ifdef __cplusplus
}
#endif
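
A hedged sketch of how a caller might exercise the new llmodel_prompt signature declared above. The callbacks, the context values, and the model handle are placeholders (the model is assumed to have been created and loaded via the existing llmodel_c API), and the prompt string is assumed to be pre-templated by the caller.

#include <cstdio>
#include "llmodel_c.h"   // include path is illustrative

// Placeholder callbacks matching the new typedefs.
static bool on_prompt(const token_t *token_ids, size_t n_token_ids, bool cached)
{
    (void)token_ids; (void)n_token_ids; (void)cached;
    return true;                          // keep processing the prompt
}

static bool on_response(token_t token_id, const char *response)
{
    (void)token_id;
    fputs(response, stdout);              // stream the generated piece
    return true;                          // keep generating
}

// Assumes `model` was created and loaded via the existing llmodel_c API.
static bool run_prompt(llmodel_model model, const char *rendered_prompt)
{
    llmodel_prompt_context ctx = {
        .n_predict = 128, .top_k = 40, .top_p = 0.9f, .min_p = 0.0f, .temp = 0.7f,
        .n_batch = 8, .repeat_penalty = 1.18f, .repeat_last_n = 64, .context_erase = 0.5f,
    };
    const char *error = nullptr;
    if (!llmodel_prompt(model, rendered_prompt, on_prompt, on_response, &ctx, &error)) {
        fprintf(stderr, "llmodel_prompt failed: %s\n", error);
        return false;
    }
    return true;
}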

View File

@@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
}
}
@@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const
return bytesRead;
}
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str) const
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
/*parse_special*/ special, /*insert_space*/ insertSpace
int32_t fres_len = llama_tokenize(
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true
);
fres.resize(fres_len);
if (fres_len)
m_tokenize_last_token = fres.back();
return fres;
}
@@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const
return std::string(result.data(), result.size());
}
void LLamaModel::initSampler(PromptContext &promptCtx)
void LLamaModel::initSampler(const PromptContext &promptCtx)
{
auto *model = d_ptr->model;
auto *chain = d_ptr->sampler_chain;
@@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
bool LLamaModel::evalTokens(int32_t nPast, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
assert(!tokens.empty());
llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1);
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
@@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token [i] = tokens[i];
batch.pos [i] = ctx.n_past + i;
batch.pos [i] = nPast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i][0] = 0;
batch.logits [i] = false;
@@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
return res == 0;
}
void LLamaModel::shiftContext(PromptContext &promptCtx)
void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast)
{
// infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_past = *nPast;
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));
assert(n_discard > 0);
@@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
*nPast = inp.size();
}
int32_t LLamaModel::contextLength() const
@@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx);
}
auto LLamaModel::specialTokens() -> std::unordered_map<std::string, std::string> const
{
if (!d_ptr->model)
throw std::logic_error("model not loaded");
std::unordered_map<std::string, std::string> tokens;
if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("bos_token", tokenToString(id));
if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("eos_token", tokenToString(id));
return tokens;
}
int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}
void LLamaModel::setTokenizeInputPosition(int32_t pos)
int32_t LLamaModel::computeModelInputPosition(std::span<const Token> input) const
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}
auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}
// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
++cacheIt; ++inputIt;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
return inputIt - input.begin();
}
void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
void LLamaModel::setModelInputPosition(int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
@@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}
void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
void LLamaModel::appendInputToken(Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}
auto LLamaModel::inputTokens() const -> std::span<const Token>
@@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const
return get_arch_key_u32(modelPath, "block_count");
}
// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded
// models into a class that keeps the model file open
auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
auto *ctx = load_gguf(modelPath);
if (!ctx)
return std::unexpected("failed to open model file");
std::expected<std::string, std::string> result;
enum gguf_type ktype;
const int kid = gguf_find_key(ctx, "tokenizer.chat_template");
if (kid == -1) {
result = std::unexpected("key not found");
goto cleanup;
}
ktype = gguf_get_kv_type(ctx, kid);
if (ktype != GGUF_TYPE_STRING) {
result = std::unexpected(
"expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype)
);
goto cleanup;
}
result = gguf_get_val_str(ctx, kid);
cleanup:
gguf_free(ctx);
return result;
}
#ifdef GGML_USE_VULKAN
static const char *getVulkanVendorName(uint32_t vendorID)
{

View File

@@ -11,6 +11,7 @@
#include <string>
#include <string_view>
#include <vector>
#include <unordered_map>
struct LLamaPrivate;
struct EmbModelSpec;
@@ -49,26 +50,26 @@ public:
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
int32_t contextLength() const override;
auto specialTokens() -> std::unordered_map<std::string, std::string> const override;
protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
std::vector<Token> tokenize(std::string_view str) const override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
void initSampler(const PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override;
void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
int32_t computeModelInputPosition(std::span<const Token> input) const override;
void setModelInputPosition(int32_t pos) override;
void appendInputToken(Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override;
auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string> override;
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,

View File

@@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath)
return llama && llama->isEmbeddingModel(modelPath);
}
auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>
{
auto *llama = constructGlobalLlama();
return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available");
}
void LLModel::Implementation::setImplementationsSearchPath(const std::string& path)
{
s_implementations_search_path = path;
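
Because the new chatTemplate() getters report failure through std::expected rather than by throwing, a caller would typically branch on the result. A minimal sketch, not part of the diff; the include path and the printChatTemplate helper are illustrative.

#include <iostream>
#include "llmodel.h"   // include path is illustrative

// Hypothetical: print the model's built-in Jinja chat template, if it has one.
void printChatTemplate(const char *modelPath)
{
    auto tmpl = LLModel::Implementation::chatTemplate(modelPath);
    if (tmpl)
        std::cout << *tmpl << '\n';
    else
        std::cerr << "no usable chat template: " << tmpl.error() << '\n';
}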

View File

@@ -7,7 +7,6 @@
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
@@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token));
struct LLModelWrapper {
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; }
};
@@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply)
bool llmodel_prompt(llmodel_model model,
const char *prompt,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_prompt_context *ctx,
const char **error)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
auto response_func = [response_callback](int32_t token_id, const std::string &response) {
return response_callback(token_id, response.c_str());
// Copy the C prompt context
LLModel::PromptContext promptContext {
.n_predict = ctx->n_predict,
.top_k = ctx->top_k,
.top_p = ctx->top_p,
.min_p = ctx->min_p,
.temp = ctx->temp,
.n_batch = ctx->n_batch,
.repeat_penalty = ctx->repeat_penalty,
.repeat_last_n = ctx->repeat_last_n,
.contextErase = ctx->context_erase,
};
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.min_p = ctx->min_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
auto prompt_func = [prompt_callback](std::span<const LLModel::Token> token_ids, bool cached) {
return prompt_callback(token_ids.data(), token_ids.size(), cached);
};
auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) {
return response_callback(token_id, piece.data());
};
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
try {
wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext);
} catch (std::exception const &e) {
llmodel_set_error(error, e.what());
return false;
}
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->min_p = wrapper->promptContext.min_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
return true;
}
float *llmodel_embed(
@@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model)
const auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->gpuDeviceName();
}
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
try {
return wrapper->llModel->countPromptTokens(prompt);
} catch (const std::exception& e) {
llmodel_set_error(error, e.what());
return -1;
}
}
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
for (auto &[name, token] : wrapper->llModel->specialTokens())
callback(name.c_str(), token.c_str());
}
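
The two new C entry points above might be exercised like this. The inspect_model helper and the pre-rendered prompt are placeholders, and the model handle is assumed to be loaded already.

#include <cstdint>
#include <cstdio>
#include "llmodel_c.h"   // include path is illustrative

static void print_special_token(const char *name, const char *token)
{
    printf("%s = %s\n", name, token);     // e.g. bos_token / eos_token
}

// Assumes `model` was created and loaded via the existing llmodel_c API.
static void inspect_model(llmodel_model model, const char *rendered_prompt)
{
    const char *error = nullptr;
    int32_t n = llmodel_count_prompt_tokens(model, rendered_prompt, &error);
    if (n < 0)
        fprintf(stderr, "token count failed: %s\n", error);
    else
        printf("prompt is %d tokens\n", n);

    llmodel_model_foreach_special_token(model, print_special_token);
}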

View File

@@ -4,232 +4,120 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <optional>
#include <regex>
#include <sstream>
#include <ranges>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
namespace ranges = std::ranges;
namespace views = std::ranges::views;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
void LLModel::prompt(
std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
if (!isModelLoaded())
throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!supportsCompletion())
throw std::invalid_argument("Not a text completion model.");
if (!promptCtx.n_batch)
throw std::invalid_argument("Batch size cannot be zero.");
if (!promptCtx.n_predict)
return; // nothing requested
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
placeholders.clear();
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
auto embd_inp = tokenize(prompt);
if (embd_inp.empty())
throw std::invalid_argument("Prompt tokenized to zero tokens.");
if (placeholders.size() > 2) {
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp)))
generateResponse(responseCallback, promptCtx, /*n_past*/ *res);
}
void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::optional<std::string_view> fakeReply)
int32_t LLModel::countPromptTokens(std::string_view prompt) const
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
return;
}
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
responseCallback(-1, errorMessage);
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
return;
}
// sanity checks
if (promptCtx.n_past > contextLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
setTokenizeInputPosition(promptCtx.n_past);
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty())
embd_inp = tokenize(userPrefix, true);
// user input (shouldn't have special token processing)
auto tokens = tokenize(prompt, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(userToAsst, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
}
}
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error
// decode the assistant's reply, either generated or spoofed
if (!fakeReply) {
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(*fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
return; // error
}
// decode the rest of the prompt template
// template: end of assistant prompt
std::string asstSuffix;
if (placeholders.size() >= 2) {
size_t start = placeholders[1].position() + placeholders[1].length();
asstSuffix = promptTemplate.substr(start);
} else {
asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(asstSuffix, true);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
}
if (!isModelLoaded())
throw std::invalid_argument("Attempted to tokenize with an unloaded model.");
return int32_t(tokenize(prompt).size());
}
// returns false on error
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse,
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << contextLength() << "!\n";
return false;
}
auto LLModel::decodePrompt(
const PromptCallback &promptCallback,
const PromptContext &promptCtx,
std::vector<Token> embd_inp
) -> std::optional<int32_t>
{
assert(!embd_inp.empty());
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << contextLength() << "\n";
return false;
}
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}
int32_t nCtx = contextLength();
int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();
int32_t nPast = computeModelInputPosition(embd_inp);
// always decode up to a full batch before generating, even if cached
if (alwaysDecode)
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
nPast -= std::min(n_batch, nPast);
setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
// TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache
if (!nPast && int32_t(embd_inp.size()) > nCtx) {
// no cache hit -> shift the input before even processing
// execute the callback even for skipped tokens
size_t i = 0;
for (; i < start_offset; i++) {
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
int32_t nKeep = shouldAddBOS();
auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase));
int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength));
// execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care
auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard);
if (!promptCallback(discardedTokens, true))
return std::nullopt;
// erase nDiscard tokens
embd_inp.erase(discardedTokens.begin(), discardedTokens.end());
assert(int32_t(embd_inp.size()) <= nCtx);
// check the cache again, just in case
nPast = computeModelInputPosition(embd_inp);
nPast -= std::min(n_batch, nPast);
}
setModelInputPosition(nPast);
// execute the callback even for skipped tokens
if (!promptCallback(embd_inp | views::take(nPast), true))
return std::nullopt;
// process the prompt in batches
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
for (int32_t i = nPast; i < embd_inp.size();) {
auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size()));
std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
if (nPast + int32_t(batch.size()) > nCtx) {
shiftContext(promptCtx, &nPast);
assert(nPast + int32_t(batch.size()) <= nCtx);
}
if (!evalTokens(promptCtx, batch)) {
std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
return false;
}
// FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation
if (!evalTokens(nPast, batch))
throw std::runtime_error("An internal error was encountered during prompt processing.");
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
Token tok = batch[t];
appendInputToken(promptCtx, tok);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
for (auto &tok : batch) {
appendInputToken(tok);
nPast++;
if (!promptCallback({ &tok, 1 }, false))
return std::nullopt;
}
i = batch_end;
}
return true;
return nPast;
}
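
To illustrate the prefix-reuse idea that replaces the old n_past bookkeeping, here is a standalone sketch, not part of the diff: it compares a cached token sequence against the new input and backs off by up to one batch so the model always decodes something before sampling, mirroring computeModelInputPosition plus the back-off above.

#include <algorithm>
#include <cstdint>
#include <span>

using Token = int32_t;

// Number of leading input tokens already present in the cache, reduced by up to
// one batch so that decoding never starts from an empty evaluation.
int32_t reusablePrefix(std::span<const Token> cache, std::span<const Token> input, int32_t nBatch)
{
    auto [cacheIt, inputIt] = std::mismatch(cache.begin(), cache.end(), input.begin(), input.end());
    auto nPast = int32_t(inputIt - input.begin());
    nPast -= std::min(nBatch, nPast);
    return nPast;
}

For example, with a cache of {1, 2, 3, 4}, an input of {1, 2, 3, 5, 6}, and nBatch = 1, the common prefix is three tokens and the function returns 2, so two tokens are treated as cached and decoding resumes at the third.
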
/*
@@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st
return std::string::npos;
}
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx) {
void LLModel::generateResponse(
const ResponseCallback &responseCallback,
const PromptContext &promptCtx,
int32_t nPast
) {
static const char *stopSequences[] {
"### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
"### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context",
"<|im_start|>", "<|im_end|>", "<|endoftext|>",
};
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
initSampler(promptCtx);
std::string cachedResponse;
@@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece;
auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
auto accept = [this, &promptCtx, &new_tok, &nPast] {
// Shift context if out of space
if (promptCtx.n_past >= contextLength()) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < contextLength());
if (nPast >= contextLength()) {
shiftContext(promptCtx, &nPast);
assert(nPast < contextLength());
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { &tok, 1 })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
if (!evalTokens(nPast, { &tok, 1 }))
throw std::runtime_error("An internal error was encountered during response generation.");
appendInputToken(promptCtx, tok);
return true;
appendInputToken(tok);
nPast++;
};
// Check for EOS
@@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
lengthLimit = cachedResponse.size() - new_piece.size();
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< contextLength() << "\n";
stop = true;
}
// Empty the cache, up to the length limit
std::string::size_type responseLength = 0;
while (!cachedTokens.empty()) {
@@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());
// Accept the token, if needed (not cached)
if (cachedTokens.empty() && new_tok && !accept())
return;
if (cachedTokens.empty() && new_tok)
accept();
// Send the token
if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
@@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
if (stop) {
cachedTokens.pop_back();
} else if (!accept()) {
return;
} else {
accept();
}
}
}
@@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif
promptCtx.n_past -= cachedTokens.size();
}
void LLModel::embed(