Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-10 04:49:07 +00:00)

Remove binary state from high-level API and use Jinja templates (#3147)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
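The commit title describes the shape of the new high-level API: the `%1`/`%2` prompt-template plumbing and the serialized `n_past` position are gone, and callers are expected to render the conversation with the model's Jinja chat template before calling `prompt()`. Below is a minimal sketch, not part of the patch, of how a caller might drive the reworked C++ interface; the header name, the `runPrompt` helper, and the already-loaded `model` pointer are assumptions for illustration.

#include "llmodel.h" // assumed header providing LLModel
#include <exception>
#include <iostream>
#include <span>
#include <string_view>

void runPrompt(LLModel *model, std::string_view rendered)
{
    // PromptContext no longer carries n_past; the decode position lives inside the model.
    LLModel::PromptContext ctx;
    ctx.n_predict = 128;

    auto onPrompt = [](std::span<const LLModel::Token> batch, bool cached) {
        (void)batch; (void)cached;
        return true; // returning false cancels prompt processing
    };
    auto onResponse = [](LLModel::Token /*token*/, std::string_view piece) {
        std::cout << piece << std::flush;
        return true; // returning false stops generation
    };

    try {
        // `rendered` is assumed to be the conversation already rendered through the
        // model's chat template; prompt() no longer applies a template itself.
        model->prompt(rendered, onPrompt, onResponse, ctx);
    } catch (const std::exception &e) {
        std::cerr << "prompt failed: " << e.what() << '\n'; // errors are now exceptions
    }
}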
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <expected>
 #include <functional>
 #include <optional>
 #include <span>
@@ -24,6 +25,10 @@ using namespace std::string_literals;
 class LLModel {
 public:
     using Token = int32_t;
+    using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
+    using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
+    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+    using ProgressCallback = std::function<bool(float progress)>;

     class BadArchError: public std::runtime_error {
     public:
@@ -101,6 +106,7 @@ public:
         static int32_t maxContextLength(const std::string &modelPath);
         static int32_t layerCount(const std::string &modelPath);
         static bool isEmbeddingModel(const std::string &modelPath);
+        static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
         static void setImplementationsSearchPath(const std::string &path);
         static const std::string &implementationsSearchPath();
         static bool hasSupportedCPU();
@@ -124,7 +130,6 @@ public:
     };

     struct PromptContext {
-        int32_t n_past = 0; // number of tokens in past conversation
         int32_t n_predict = 200;
         int32_t top_k = 40;
         float top_p = 0.9f;
@@ -136,8 +141,6 @@ public:
         float contextErase = 0.5f; // percent of context to erase if we exceed the context window
     };

-    using ProgressCallback = std::function<bool(float progress)>;
-
     explicit LLModel() {}
     virtual ~LLModel() {}

@@ -154,16 +157,12 @@ public:

     // This method requires the model to return true from supportsCompletion otherwise it will throw
     // an error
-    virtual void prompt(const std::string &prompt,
-                        const std::string &promptTemplate,
-                        std::function<bool(int32_t)> promptCallback,
-                        std::function<bool(int32_t, const std::string&)> responseCallback,
-                        bool allowContextShift,
-                        PromptContext &ctx,
-                        bool special = false,
-                        std::optional<std::string_view> fakeReply = {});
+    virtual void prompt(std::string_view prompt,
+                        const PromptCallback &promptCallback,
+                        const ResponseCallback &responseCallback,
+                        const PromptContext &ctx);

-    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+    virtual int32_t countPromptTokens(std::string_view prompt) const;

     virtual size_t embeddingSize() const {
         throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@@ -209,23 +208,22 @@ public:
     void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

     virtual int32_t contextLength() const = 0;
+    virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;

 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
+    virtual std::vector<Token> tokenize(std::string_view str) const = 0;
     virtual bool isSpecialToken(Token id) const = 0;
     virtual std::string tokenToString(Token id) const = 0;
-    virtual void initSampler(PromptContext &ctx) = 0;
+    virtual void initSampler(const PromptContext &ctx) = 0;
     virtual Token sampleToken() const = 0;
-    virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
-    virtual void shiftContext(PromptContext &promptCtx) = 0;
+    virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
+    virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
     virtual int32_t inputLength() const = 0;
-    virtual void setTokenizeInputPosition(int32_t pos) = 0;
-    virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-        -> std::vector<Token>::const_iterator = 0;
-    virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
-    virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
+    virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
+    virtual void setModelInputPosition(int32_t pos) = 0;
+    virtual void appendInputToken(Token tok) = 0;
     virtual std::span<const Token> inputTokens() const = 0;
     virtual const std::vector<Token> &endTokens() const = 0;
     virtual bool shouldAddBOS() const = 0;
@@ -242,6 +240,12 @@ protected:
         return -1;
     }

+    virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
+    {
+        (void)modelPath;
+        return std::unexpected("not implemented");
+    }
+
     const Implementation *m_implementation = nullptr;

     ProgressCallback m_progressCallback;
@@ -253,19 +257,15 @@ protected:
         return true;
     }

-    bool decodePrompt(std::function<bool(int32_t)> promptCallback,
-                      std::function<bool(int32_t, const std::string&)> responseCallback,
-                      bool allowContextShift,
-                      PromptContext &promptCtx,
-                      std::vector<Token> embd_inp,
-                      bool isResponse = false,
-                      bool alwaysDecode = false);
-    void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
-                          bool allowContextShift,
-                          PromptContext &promptCtx);
-
-protected:
-    Token m_tokenize_last_token = -1; // not serialized
+    // prefill context with prompt
+    auto decodePrompt(const PromptCallback &promptCallback,
+                      const PromptContext &promptCtx,
+                      std::vector<Token> embd_inp)
+        -> std::optional<int32_t>;
+    // generate a response
+    void generateResponse(const ResponseCallback &responseCallback,
+                          const PromptContext &promptCtx,
+                          int32_t nPast);

     friend class LLMImplementation;
 };

@@ -35,16 +35,15 @@ typedef int32_t token_t;
 * behavior.
 */
 struct llmodel_prompt_context {
-    int32_t n_past;       // number of tokens in past conversation
     int32_t n_predict;    // number of tokens to predict
     int32_t top_k;        // top k logits to sample from
     float top_p;          // nucleus sampling probability threshold
     float min_p;          // Min P sampling
     float temp;           // temperature to adjust model's output distribution
     int32_t n_batch;      // number of predictions to generate in parallel
     float repeat_penalty; // penalty factor for repeated tokens
     int32_t repeat_last_n; // last n tokens to penalize
     float context_erase;  // percent of context to erase if we exceed the context window
 };

 struct llmodel_gpu_device {
@@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;

 /**
  * Callback type for prompt processing.
- * @param token_id The token id of the prompt.
+ * @param token_ids An array of token ids of the prompt.
+ * @param n_token_ids The number of tokens in the array.
+ * @param cached Whether the tokens were already in cache.
  * @return a bool indicating whether the model should keep processing.
  */
-typedef bool (*llmodel_prompt_callback)(int32_t token_id);
+typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);

 /**
  * Callback type for response.
@@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
  * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
  * @return a bool indicating whether the model should keep generating.
  */
-typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
+typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);

 /**
  * Embedding cancellation callback for use with llmodel_embed.
@@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
  */
 typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);

+typedef void (*llmodel_special_token_callback)(const char *name, const char *token);
+
 /**
  * Create a llmodel instance.
  * Recognises correct model type from file at model_path
@@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
  * @param prompt A string representing the input prompt.
- * @param prompt_template A string representing the input prompt template.
  * @param prompt_callback A callback function for handling the processing of prompt.
  * @param response_callback A callback function for handling the generated response.
- * @param allow_context_shift Whether to allow shifting of context to make room for more input.
- * @param special True if special tokens in the prompt should be processed, false otherwise.
- * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
  * @param ctx A pointer to the llmodel_prompt_context structure.
+ * @param error A pointer to a string; will only be set on error.
  */
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
-                    llmodel_prompt_callback prompt_callback,
-                    llmodel_response_callback response_callback,
-                    bool allow_context_shift,
-                    llmodel_prompt_context *ctx,
-                    bool special,
-                    const char *fake_reply);
+bool llmodel_prompt(llmodel_model model,
+                    const char *prompt,
+                    llmodel_prompt_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_prompt_context *ctx,
+                    const char **error);

 /**
  * Generate an embedding using the model.
@@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
  */
 const char *llmodel_model_gpu_device_name(llmodel_model model);

+int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
+
+void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
+
 #ifdef __cplusplus
 }
 #endif

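For reference, a sketch of how a client might exercise the revised C surface declared above. The `generate` wrapper, the sampling values, and the assumption that the model handle was already created and loaded through the existing llmodel_* loading functions are all illustrative; only the declarations shown in the hunks are taken from the patch.

#include <cstdio>

static bool on_prompt(const token_t *token_ids, size_t n_token_ids, bool cached)
{
    (void)token_ids; (void)n_token_ids; (void)cached;
    return true;                      // false would cancel prompt processing
}

static bool on_response(token_t token_id, const char *response)
{
    if (token_id == -1) return false; // per the doc comment, -1 marks an error string
    std::fputs(response, stdout);
    return true;
}

bool generate(llmodel_model model, const char *rendered_prompt)
{
    llmodel_prompt_context ctx {};    // n_past no longer exists in the struct
    ctx.n_predict      = 128;
    ctx.top_k          = 40;
    ctx.top_p          = 0.9f;
    ctx.temp           = 0.7f;
    ctx.n_batch        = 8;
    ctx.repeat_penalty = 1.1f;
    ctx.repeat_last_n  = 64;
    ctx.context_erase  = 0.5f;

    const char *error = nullptr;
    if (!llmodel_prompt(model, rendered_prompt, on_prompt, on_response, &ctx, &error)) {
        std::fprintf(stderr, "prompt failed: %s\n", error ? error : "unknown error");
        return false;
    }
    return true;
}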
@@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
         if (keyidx != -1) {
             value = gguf_get_val_u32(ctx, keyidx);
         } else {
-            std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
+            std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
         }
     }

@@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const
     return bytesRead;
 }

-std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
+std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str) const
 {
-    bool atStart = m_tokenize_last_token == -1;
-    bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
     std::vector<LLModel::Token> fres(str.length() + 4);
-    int32_t fres_len = llama_tokenize_gpt4all(
-        d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
-        /*parse_special*/ special, /*insert_space*/ insertSpace
+    int32_t fres_len = llama_tokenize(
+        d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true
     );
     fres.resize(fres_len);
-    if (fres_len)
-        m_tokenize_last_token = fres.back();
     return fres;
 }

@@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const
     return std::string(result.data(), result.size());
 }

-void LLamaModel::initSampler(PromptContext &promptCtx)
+void LLamaModel::initSampler(const PromptContext &promptCtx)
 {
     auto *model = d_ptr->model;
     auto *chain = d_ptr->sampler_chain;
@@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const
     return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
 }

-bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
+bool LLamaModel::evalTokens(int32_t nPast, std::span<const Token> tokens) const
 {
-    llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
+    assert(!tokens.empty());
+
+    llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1);

     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);

@@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c

     for (int32_t i = 0; i < batch.n_tokens; i++) {
         batch.token [i] = tokens[i];
-        batch.pos [i] = ctx.n_past + i;
+        batch.pos [i] = nPast + i;
         batch.n_seq_id[i] = 1;
         batch.seq_id [i][0] = 0;
         batch.logits [i] = false;
@@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
     return res == 0;
 }

-void LLamaModel::shiftContext(PromptContext &promptCtx)
+void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast)
 {
     // infinite text generation via context shifting

     // erase up to n_ctx*contextErase tokens
     int n_keep = shouldAddBOS();
-    int n_past = promptCtx.n_past;
+    int n_past = *nPast;
     int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

     assert(n_discard > 0);
@@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)

     auto &inp = d_ptr->inputTokens;
     inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
-    promptCtx.n_past = inp.size();
+    *nPast = inp.size();
 }

 int32_t LLamaModel::contextLength() const
@@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const
     return llama_n_ctx(d_ptr->ctx);
 }

+auto LLamaModel::specialTokens() -> std::unordered_map<std::string, std::string> const
+{
+    if (!d_ptr->model)
+        throw std::logic_error("model not loaded");
+
+    std::unordered_map<std::string, std::string> tokens;
+    if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL)
+        tokens.emplace("bos_token", tokenToString(id));
+    if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL)
+        tokens.emplace("eos_token", tokenToString(id));
+    return tokens;
+}
+
 int32_t LLamaModel::inputLength() const
 {
     return d_ptr->inputTokens.size();
 }

-void LLamaModel::setTokenizeInputPosition(int32_t pos)
-{
-    assert(pos >= 0);
-    m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
-}
-
-auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-    -> std::vector<Token>::const_iterator
+int32_t LLamaModel::computeModelInputPosition(std::span<const Token> input) const
 {
-    assert(ctx.n_past >= 0);
-    auto pos = size_t(ctx.n_past);
-    if (pos > d_ptr->inputTokens.size()) {
-        std::ostringstream ss;
-        ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
-        throw std::out_of_range(ss.str());
-    }
-
     // find common prefix
     auto cacheIt = d_ptr->inputTokens.begin();
     auto inputIt = input.begin();
     while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
-        ++cacheIt; ++inputIt; ++pos;
+        ++cacheIt; ++inputIt;
     }
     // tell the caller to ignore the tokens between [begin, inputIt)
-    return inputIt;
+    return inputIt - input.begin();
 }

-void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
+void LLamaModel::setModelInputPosition(int32_t pos)
 {
     auto &inp = d_ptr->inputTokens;
     assert(pos >= 0);
@@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
     // truncate token cache to end at the new n_past
     if (pos < inp.size())
         inp.resize(pos);
-    ctx.n_past = pos;
 }

-void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
+void LLamaModel::appendInputToken(Token tok)
 {
     d_ptr->inputTokens.push_back(tok);
-    ctx.n_past += 1;
 }

 auto LLamaModel::inputTokens() const -> std::span<const Token>
@@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const
     return get_arch_key_u32(modelPath, "block_count");
 }

+// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded
+// models into a class that keeps the model file open
+auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
+{
+    auto *ctx = load_gguf(modelPath);
+    if (!ctx)
+        return std::unexpected("failed to open model file");
+
+    std::expected<std::string, std::string> result;
+    enum gguf_type ktype;
+
+    const int kid = gguf_find_key(ctx, "tokenizer.chat_template");
+    if (kid == -1) {
+        result = std::unexpected("key not found");
+        goto cleanup;
+    }
+
+    ktype = gguf_get_kv_type(ctx, kid);
+    if (ktype != GGUF_TYPE_STRING) {
+        result = std::unexpected(
+            "expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype)
+        );
+        goto cleanup;
+    }
+
+    result = gguf_get_val_str(ctx, kid);
+
+cleanup:
+    gguf_free(ctx);
+    return result;
+}
+
 #ifdef GGML_USE_VULKAN
 static const char *getVulkanVendorName(uint32_t vendorID)
 {

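The new computeModelInputPosition() above replaces the old iterator-returning variant: it reports how many leading tokens of the new input already sit in the model's token cache, so decoding can resume right after the common prefix. A standalone sketch of that rule, illustrative rather than taken from the patch:

#include <cstdint>
#include <span>

// Length of the longest common prefix between the cached tokens and the new input.
// Tokens [0, n) can be kept in the KV cache; decoding resumes at position n.
int32_t commonPrefixLength(std::span<const int32_t> cache, std::span<const int32_t> input)
{
    int32_t n = 0;
    while (n < int32_t(cache.size()) && n < int32_t(input.size()) && cache[n] == input[n])
        ++n;
    return n;
}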
@@ -11,6 +11,7 @@
 #include <string>
 #include <string_view>
 #include <vector>
+#include <unordered_map>

 struct LLamaPrivate;
 struct EmbModelSpec;
@@ -49,26 +50,26 @@ public:
                        size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

     int32_t contextLength() const override;
+    auto specialTokens() -> std::unordered_map<std::string, std::string> const override;

 protected:
-    std::vector<Token> tokenize(std::string_view str, bool special) override;
+    std::vector<Token> tokenize(std::string_view str) const override;
     bool isSpecialToken(Token id) const override;
     std::string tokenToString(Token id) const override;
-    void initSampler(PromptContext &ctx) override;
+    void initSampler(const PromptContext &ctx) override;
     Token sampleToken() const override;
-    bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
-    void shiftContext(PromptContext &promptCtx) override;
+    bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override;
+    void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override;
     int32_t inputLength() const override;
-    void setTokenizeInputPosition(int32_t pos) override;
-    auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-        -> std::vector<Token>::const_iterator override;
-    void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
-    void appendInputToken(PromptContext &ctx, Token tok) override;
+    int32_t computeModelInputPosition(std::span<const Token> input) const override;
+    void setModelInputPosition(int32_t pos) override;
+    void appendInputToken(Token tok) override;
     std::span<const Token> inputTokens() const override;
     const std::vector<Token> &endTokens() const override;
     bool shouldAddBOS() const override;
     int32_t maxContextLength(std::string const &modelPath) const override;
     int32_t layerCount(std::string const &modelPath) const override;
+    auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string> override;

     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
                        size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,

@@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath)
     return llama && llama->isEmbeddingModel(modelPath);
 }

+auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>
+{
+    auto *llama = constructGlobalLlama();
+    return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available");
+}
+
 void LLModel::Implementation::setImplementationsSearchPath(const std::string& path)
 {
     s_implementations_search_path = path;

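With the static Implementation::chatTemplate() hook added above, a frontend can ask a model file for its bundled Jinja template without loading the weights, and fall back if none is usable. A sketch under the assumption that C++23 std::expected is available (as the new <expected> include implies); the printChatTemplate helper and the path handling are illustrative:

#include <iostream>
#include <string>

void printChatTemplate(const std::string &modelPath)
{
    // Returns the raw tokenizer.chat_template string on success, or an error message.
    auto tmpl = LLModel::Implementation::chatTemplate(modelPath.c_str());
    if (tmpl)
        std::cout << "tokenizer.chat_template:\n" << *tmpl << "\n";
    else
        std::cout << "no usable chat template: " << tmpl.error() << "\n";
}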
@@ -7,7 +7,6 @@
 #include <cstdlib>
 #include <cstring>
 #include <exception>
-#include <functional>
 #include <iostream>
 #include <memory>
 #include <optional>
@@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token));

 struct LLModelWrapper {
     LLModel *llModel = nullptr;
-    LLModel::PromptContext promptContext;
     ~LLModelWrapper() { delete llModel; }
 };

@@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
     return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
 }

-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
-                    llmodel_prompt_callback prompt_callback,
-                    llmodel_response_callback response_callback,
-                    bool allow_context_shift,
-                    llmodel_prompt_context *ctx,
-                    bool special,
-                    const char *fake_reply)
+bool llmodel_prompt(llmodel_model model,
+                    const char *prompt,
+                    llmodel_prompt_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_prompt_context *ctx,
+                    const char **error)
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);

-    auto response_func = [response_callback](int32_t token_id, const std::string &response) {
-        return response_callback(token_id, response.c_str());
-    };
-
-    // Copy the C prompt context
-    wrapper->promptContext.n_past = ctx->n_past;
-    wrapper->promptContext.n_predict = ctx->n_predict;
-    wrapper->promptContext.top_k = ctx->top_k;
-    wrapper->promptContext.top_p = ctx->top_p;
-    wrapper->promptContext.min_p = ctx->min_p;
-    wrapper->promptContext.temp = ctx->temp;
-    wrapper->promptContext.n_batch = ctx->n_batch;
-    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
-    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
-    wrapper->promptContext.contextErase = ctx->context_erase;
-
-    // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-                             wrapper->promptContext, special,
-                             fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
-
-    // Update the rest of the C prompt context
-    ctx->n_past = wrapper->promptContext.n_past;
-    ctx->n_predict = wrapper->promptContext.n_predict;
-    ctx->top_k = wrapper->promptContext.top_k;
-    ctx->top_p = wrapper->promptContext.top_p;
-    ctx->min_p = wrapper->promptContext.min_p;
-    ctx->temp = wrapper->promptContext.temp;
-    ctx->n_batch = wrapper->promptContext.n_batch;
-    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
-    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
-    ctx->context_erase = wrapper->promptContext.contextErase;
+    // Copy the C prompt context
+    LLModel::PromptContext promptContext {
+        .n_predict = ctx->n_predict,
+        .top_k = ctx->top_k,
+        .top_p = ctx->top_p,
+        .min_p = ctx->min_p,
+        .temp = ctx->temp,
+        .n_batch = ctx->n_batch,
+        .repeat_penalty = ctx->repeat_penalty,
+        .repeat_last_n = ctx->repeat_last_n,
+        .contextErase = ctx->context_erase,
+    };
+
+    auto prompt_func = [prompt_callback](std::span<const LLModel::Token> token_ids, bool cached) {
+        return prompt_callback(token_ids.data(), token_ids.size(), cached);
+    };
+    auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) {
+        return response_callback(token_id, piece.data());
+    };
+
+    try {
+        wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext);
+    } catch (std::exception const &e) {
+        llmodel_set_error(error, e.what());
+        return false;
+    }
+
+    return true;
 }

@@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model)
     const auto *wrapper = static_cast<LLModelWrapper *>(model);
     return wrapper->llModel->gpuDeviceName();
 }
+
+int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error)
+{
+    auto *wrapper = static_cast<const LLModelWrapper *>(model);
+    try {
+        return wrapper->llModel->countPromptTokens(prompt);
+    } catch (const std::exception& e) {
+        llmodel_set_error(error, e.what());
+        return -1;
+    }
+}
+
+void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
+{
+    auto *wrapper = static_cast<const LLModelWrapper *>(model);
+    for (auto &[name, token] : wrapper->llModel->specialTokens())
+        callback(name.c_str(), token.c_str());
+}

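The two new C entry points wrapped above are straightforward to call; a short sketch (the inspect helper and the literal prompt are illustrative):

#include <cstdio>

static void print_special_token(const char *name, const char *token)
{
    std::printf("%s = %s\n", name, token); // e.g. bos_token, eos_token
}

void inspect(llmodel_model model)
{
    const char *error = nullptr;
    int32_t n = llmodel_count_prompt_tokens(model, "Hello, world!", &error);
    if (n < 0)
        std::printf("count failed: %s\n", error ? error : "unknown");
    else
        std::printf("prompt is %d tokens\n", n);

    llmodel_model_foreach_special_token(model, print_special_token);
}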
@@ -4,232 +4,120 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
 #include <functional>
 #include <iostream>
 #include <iterator>
 #include <optional>
-#include <regex>
-#include <sstream>
+#include <ranges>
 #include <stdexcept>
 #include <string>
 #include <string_view>
 #include <vector>

 namespace ranges = std::ranges;
 namespace views = std::ranges::views;

-static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
-{
-    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
-
-    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
-    placeholders.clear();
-    placeholders.insert(placeholders.end(), it, std::sregex_iterator());
-
-    if (placeholders.size() > 2) {
-        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
-        return false;
-    }
-    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
-        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
-        return false;
-    }
-    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
-        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
-        return false;
-    }
-    return true;
-}
+void LLModel::prompt(
+    std::string_view prompt,
+    const PromptCallback &promptCallback,
+    const ResponseCallback &responseCallback,
+    const PromptContext &promptCtx
+) {
+    if (!isModelLoaded())
+        throw std::invalid_argument("Attempted to prompt an unloaded model.");
+    if (!supportsCompletion())
+        throw std::invalid_argument("Not a text completion model.");
+    if (!promptCtx.n_batch)
+        throw std::invalid_argument("Batch size cannot be zero.");
+    if (!promptCtx.n_predict)
+        return; // nothing requested
+
+    auto embd_inp = tokenize(prompt);
+    if (embd_inp.empty())
+        throw std::invalid_argument("Prompt tokenized to zero tokens.");
+
+    if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp)))
+        generateResponse(responseCallback, promptCtx, /*n_past*/ *res);
+}

-void LLModel::prompt(const std::string &prompt,
-                     const std::string &promptTemplate,
-                     std::function<bool(int32_t)> promptCallback,
-                     std::function<bool(int32_t, const std::string&)> responseCallback,
-                     bool allowContextShift,
-                     PromptContext &promptCtx,
-                     bool special,
-                     std::optional<std::string_view> fakeReply)
+int32_t LLModel::countPromptTokens(std::string_view prompt) const
 {
-    if (!isModelLoaded()) {
-        std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
-        return;
-    }
-
-    if (!supportsCompletion()) {
-        std::string errorMessage = "ERROR: this model does not support text completion or chat!";
-        responseCallback(-1, errorMessage);
-        std::cerr << implementation().modelType() << " " << errorMessage << "\n";
-        return;
-    }
-
-    // sanity checks
-    if (promptCtx.n_past > contextLength()) {
-        std::ostringstream ss;
-        ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
-        throw std::out_of_range(ss.str());
-    }
-    if (promptCtx.n_past > inputLength()) {
-        std::ostringstream ss;
-        ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
-        throw std::out_of_range(ss.str());
-    }
-
-    promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
-
-    // parse the prompt template
-    std::vector<std::smatch> placeholders;
-    {
-        std::string err;
-        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
-            responseCallback(-1, err);
-            std::cerr << err << "\n";
-            return;
-        }
-    }
-
-    setTokenizeInputPosition(promptCtx.n_past);
-
-    // tokenize the user prompt
-    std::vector<Token> embd_inp;
-    if (placeholders.empty()) {
-        // this is unusual, but well-defined
-        std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptTemplate, true);
-    } else {
-        // template: beginning of user prompt
-        const auto &phUser = placeholders[0];
-        std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty())
-            embd_inp = tokenize(userPrefix, true);
-
-        // user input (shouldn't have special token processing)
-        auto tokens = tokenize(prompt, special);
-        embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-
-        // template: end of user prompt + start of assistant prompt
-        size_t start = phUser.position() + phUser.length();
-        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
-        auto userToAsst = promptTemplate.substr(start, end - start);
-        if (!userToAsst.empty()) {
-            tokens = tokenize(userToAsst, true);
-            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-        }
-    }
-
-    // decode the user prompt
-    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
-                      /*alwaysDecode*/ true))
-        return; // error
-
-    // decode the assistant's reply, either generated or spoofed
-    if (!fakeReply) {
-        generateResponse(responseCallback, allowContextShift, promptCtx);
-    } else {
-        embd_inp = tokenize(*fakeReply, false);
-        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
-            return; // error
-    }
-
-    // decode the rest of the prompt template
-    // template: end of assistant prompt
-    std::string asstSuffix;
-    if (placeholders.size() >= 2) {
-        size_t start = placeholders[1].position() + placeholders[1].length();
-        asstSuffix = promptTemplate.substr(start);
-    } else {
-        asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
-    }
-    if (!asstSuffix.empty()) {
-        embd_inp = tokenize(asstSuffix, true);
-        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
-    }
+    if (!isModelLoaded())
+        throw std::invalid_argument("Attempted to tokenize with an unloaded model.");
+    return int32_t(tokenize(prompt).size());
 }

-// returns false on error
-bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
-                           std::function<bool(int32_t, const std::string&)> responseCallback,
-                           bool allowContextShift,
-                           PromptContext &promptCtx,
-                           std::vector<Token> embd_inp,
-                           bool isResponse,
-                           bool alwaysDecode) {
-    if ((int) embd_inp.size() > contextLength() - 4) {
-        // FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
-        // translation
-        responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
-        std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
-            " tokens and the context window is " << contextLength() << "!\n";
-        return false;
-    }
+auto LLModel::decodePrompt(
+    const PromptCallback &promptCallback,
+    const PromptContext &promptCtx,
+    std::vector<Token> embd_inp
+) -> std::optional<int32_t>
+{
+    assert(!embd_inp.empty());

-    // FIXME(jared): There are mitigations for this situation, such as making room before
-    // copying the prompt context, or restoring the KV cache when we restore the prompt
-    // context.
-    if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
-        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
-                  << ", n_ctx=" << contextLength() << "\n";
-        return false;
-    }
-
-    // always decode something before generating, even if cached
-    if (alwaysDecode && embd_inp.empty()) {
-        auto cache = inputTokens();
-        if (!promptCtx.n_past)
-            throw std::runtime_error("zero token prompt is not supported");
-        assert(!cache.empty());
-        embd_inp.push_back(cache.back());
-        promptCtx.n_past--;
-    }
+    int32_t nCtx = contextLength();
+    int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);

     // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
     // requested n_past.
     // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
-    auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
-    size_t start_offset = embd_inp_start - embd_inp.begin();
+    int32_t nPast = computeModelInputPosition(embd_inp);

     // always decode up to a full batch before generating, even if cached
-    if (alwaysDecode)
-        start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
+    nPast -= std::min(n_batch, nPast);

-    setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
-
-    // execute the callback even for skipped tokens
-    size_t i = 0;
-    for (; i < start_offset; i++) {
-        Token tok = embd_inp[i];
-        bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
-        if (!res)
-            return false;
+    // TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache
+    if (!nPast && int32_t(embd_inp.size()) > nCtx) {
+        // no cache hit -> shift the input before even processing
+
+        int32_t nKeep = shouldAddBOS();
+        auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase));
+        int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength));
+
+        // execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care
+        auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard);
+        if (!promptCallback(discardedTokens, true))
+            return std::nullopt;
+
+        // erase nDiscard tokens
+        embd_inp.erase(discardedTokens.begin(), discardedTokens.end());
+        assert(int32_t(embd_inp.size()) <= nCtx);
+
+        // check the cache again, just in case
+        nPast = computeModelInputPosition(embd_inp);
+        nPast -= std::min(n_batch, nPast);
     }

+    setModelInputPosition(nPast);
+
+    // execute the callback even for skipped tokens
+    if (!promptCallback(embd_inp | views::take(nPast), true))
+        return std::nullopt;
+
     // process the prompt in batches
-    while (i < embd_inp.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
-        std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+    for (int32_t i = nPast; i < embd_inp.size();) {
+        auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size()));
+        std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

         // Check if the context has run out...
-        if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
-            assert(allowContextShift);
-            shiftContext(promptCtx);
-            assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
+        if (nPast + int32_t(batch.size()) > nCtx) {
+            shiftContext(promptCtx, &nPast);
+            assert(nPast + int32_t(batch.size()) <= nCtx);
         }

-        if (!evalTokens(promptCtx, batch)) {
-            std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
-            return false;
-        }
+        // FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation
+        if (!evalTokens(nPast, batch))
+            throw std::runtime_error("An internal error was encountered during prompt processing.");

-        size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t) {
-            Token tok = batch[t];
-            appendInputToken(promptCtx, tok);
-            bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
-            if (!res)
-                return false;
+        for (auto &tok : batch) {
+            appendInputToken(tok);
+            nPast++;
+            if (!promptCallback({ &tok, 1 }, false))
+                return std::nullopt;
         }
         i = batch_end;
     }

-    return true;
+    return nPast;
 }

 /*
@@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st
     return std::string::npos;
 }

-void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
-                               bool allowContextShift,
-                               PromptContext &promptCtx) {
+void LLModel::generateResponse(
+    const ResponseCallback &responseCallback,
+    const PromptContext &promptCtx,
+    int32_t nPast
+) {
     static const char *stopSequences[] {
-        "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
+        "### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context",
         "<|im_start|>", "<|im_end|>", "<|endoftext|>",
     };

-    // Don't even start if there is no room
-    if (!promptCtx.n_predict)
-        return;
-    if (!allowContextShift && promptCtx.n_past >= contextLength()) {
-        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
-                  << "\n";
-        return;
-    }
-
     initSampler(promptCtx);

     std::string cachedResponse;
@@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
         cachedTokens.push_back(new_tok.value());
         cachedResponse += new_piece;

-        auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
+        auto accept = [this, &promptCtx, &new_tok, &nPast] {
             // Shift context if out of space
-            if (promptCtx.n_past >= contextLength()) {
-                (void)allowContextShift;
-                assert(allowContextShift);
-                shiftContext(promptCtx);
-                assert(promptCtx.n_past < contextLength());
+            if (nPast >= contextLength()) {
+                shiftContext(promptCtx, &nPast);
+                assert(nPast < contextLength());
             }

             // Accept the token
             Token tok = std::exchange(new_tok, std::nullopt).value();
-            if (!evalTokens(promptCtx, { &tok, 1 })) {
-                // TODO(jared): raise an exception
-                std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
-                return false;
-            }
+            if (!evalTokens(nPast, { &tok, 1 }))
+                throw std::runtime_error("An internal error was encountered during response generation.");

-            appendInputToken(promptCtx, tok);
-            return true;
+            appendInputToken(tok);
+            nPast++;
         };

         // Check for EOS
@@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
             lengthLimit = cachedResponse.size() - new_piece.size();
         }

-        // Optionally stop if the context will run out
-        if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
-            std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
-                      << contextLength() << "\n";
-            stop = true;
-        }
-
         // Empty the cache, up to the length limit
         std::string::size_type responseLength = 0;
         while (!cachedTokens.empty()) {
@@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
             cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());

             // Accept the token, if needed (not cached)
-            if (cachedTokens.empty() && new_tok && !accept())
-                return;
+            if (cachedTokens.empty() && new_tok)
+                accept();

             // Send the token
             if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
@@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
         assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
         if (stop) {
             cachedTokens.pop_back();
-        } else if (!accept()) {
-            return;
+        } else {
+            accept();
         }
     }
 }
@@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
     auto discard_start = inp.end() - cachedTokens.size();
     assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
 #endif
-
-    promptCtx.n_past -= cachedTokens.size();
 }

 void LLModel::embed(
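When a brand-new prompt is longer than the context window and nothing is cached, the new decodePrompt() trims it up front using contextErase before any decoding happens. The arithmetic, pulled out as a standalone sketch with illustrative numbers:

#include <algorithm>
#include <cstdint>

// How many prompt tokens get dropped before decoding starts:
// keep at most nCtx * (1 - contextErase) tokens, but never fewer than 1.
int32_t discardCount(int32_t promptLen, int32_t nCtx, float contextErase)
{
    auto newLength = int32_t(nCtx * (1.f - contextErase));
    return promptLen - std::max(1, std::min(nCtx, newLength));
}
// e.g. discardCount(5000, 4096, 0.5f) == 5000 - 2048 == 2952 tokens dropped,
// skipping over the leading BOS token that shouldAddBOS() reserves at the front.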