fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel authored on 2024-02-21 15:45:32 -05:00, committed by GitHub
parent b8f5c74f40
commit 4fc4d94be4
22 changed files with 429 additions and 307 deletions
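
For context, the prompt templates this change is about are plain strings in which %1 stands for the user's message and an optional %2 stands for the model's reply; everything around the placeholders is the chat scaffolding. As an illustration only (the exact strings shipped with each model live in models.json, not in this diff), a ChatML-style template for Mistral OpenOrca could look like:

    // Illustrative chat-style template: %1 = user message, %2 = assistant reply.
    // The scaffolding around the placeholders may contain special tokens such as <|im_start|>.
    const std::string promptTemplate =
        "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n";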

gpt4all-backend/bert.cpp

@@ -814,8 +814,10 @@ std::vector<float> Bert::embedding(const std::string &text)
return finalEmbeddings;
}
std::vector<LLModel::Token> Bert::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> Bert::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
(void)ctx;
(void)special;
return ::bert_tokenize(d_ptr->ctx, str.c_str());
}

gpt4all-backend/bert.h

@@ -33,12 +33,13 @@ private:
std::unique_ptr<BertPrivate> d_ptr;
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string tokenToString(Token) const override;
std::string tokenToString(Token id) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override { return true; }
};
#endif // BERT_H

gpt4all-backend/gptj.cpp

@@ -737,8 +737,10 @@ size_t GPTJ::restoreState(const uint8_t *src)
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
}
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
(void)ctx;
(void)special;
return ::gpt_tokenize(d_ptr->vocab, str);
}

gpt4all-backend/gptj.h

@@ -30,12 +30,13 @@ private:
GPTJPrivate *d_ptr;
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string tokenToString(Token) const override;
std::string tokenToString(Token id) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override { return false; }
};
#endif // GPTJ_H

gpt4all-backend/llamamodel.cpp

@@ -6,38 +6,29 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <iomanip>
#include <iostream>
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h>
#else
#include <unistd.h>
#endif
#include <map>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>
#include <llama.h>
#include <ggml.h>
#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#include <ggml-kompute.h>
#endif
using namespace std::string_literals;
// Maximum supported GGUF version
static constexpr int GGUF_VER_MAX = 3;
namespace {
const char *modelType_ = "LLaMA";
}
static const char * const modelType_ = "LLaMA";
static bool llama_verbose() {
const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");
@@ -96,6 +87,56 @@ static int llama_sample_top_p_top_k(
return llama_sample_token(ctx, &candidates_p);
}
std::string get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != (GGUF_TYPE_STRING)) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
return gguf_get_val_str(ctx_gguf, kid);
}
static gguf_context *load_gguf(const char *fname) {
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ nullptr,
};
gguf_context *ctx = gguf_init_from_file(fname, params);
if (!ctx) {
std::cerr << __func__ << ": gguf_init_from_file failed\n";
return nullptr;
}
int gguf_ver = gguf_get_version(ctx);
if (gguf_ver > GGUF_VER_MAX) {
std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
gguf_free(ctx);
return nullptr;
}
return ctx;
}
static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
auto * ctx = load_gguf(modelPath.c_str());
auto arch = get_arch_name(ctx);
int32_t value = -1;
if (ctx) {
auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
}
}
gguf_free(ctx);
return value;
}
struct LLamaPrivate {
const std::string modelPath;
bool modelLoaded;
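
The helpers above let the backend read per-architecture metadata from a GGUF file without loading the full model: load_gguf() opens the file with no_alloc, get_arch_name() pulls general.architecture, and get_arch_key_u32() looks up "<arch>.<key>". A minimal usage sketch (the path is hypothetical; "context_length" is the kind of key maxContextLength() reads this way):

    // Sketch: read e.g. "llama.context_length" from a GGUF header.
    // get_arch_key_u32 returns -1 when the file or key cannot be read.
    int32_t n_ctx_train = get_arch_key_u32("/path/to/model.gguf", "context_length");
    if (n_ctx_train < 0) {
        std::cerr << "could not read the model's trained context length\n";
    }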
@@ -148,6 +189,42 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
return filesize + est_kvcache_size;
}
bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
auto * ctx = load_gguf(modelPath.c_str());
if (!ctx) {
std::cerr << __func__ << ": failed to load " << modelPath << "\n";
return false;
}
auto get_key = [ctx, &modelPath](const char *name) {
int keyidx = gguf_find_key(ctx, name);
if (keyidx == -1) {
throw std::logic_error(name + " not found in "s + modelPath);
}
return keyidx;
};
bool res = false;
try {
std::string name(gguf_get_val_str(ctx, get_key("general.name")));
int token_idx = get_key("tokenizer.ggml.tokens");
int n_vocab = gguf_get_arr_n(ctx, token_idx);
// check for known bad models
if (name == "open-orca_mistral-7b-openorca"
&& n_vocab == 32002
&& gguf_get_arr_str(ctx, token_idx, 32000) == "<dummy32000>"s // should be <|im_end|>
) {
res = true;
}
} catch (const std::logic_error &e) {
std::cerr << __func__ << ": " << e.what() << "\n";
}
gguf_free(ctx);
return res;
}
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
d_ptr->modelLoaded = false;
@@ -290,12 +367,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->model));
std::vector<LLModel::Token> fres(str.size()+4);
// TODO(cebtenzzre): we may want to use special=true here to process special tokens
auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, false);
const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
const bool useBOS = wantBOS && shouldAddBOS();
auto strCat = wantBOS && !special ? " " + str : str; // insert leading space ourselves, llama.cpp fork doesn't anymore
std::vector<LLModel::Token> fres(strCat.size()+4);
auto fres_len = llama_tokenize(d_ptr->model, strCat.c_str(), strCat.length(), fres.data(), fres.size(), useBOS, special);
fres.resize(fres_len);
return fres;
}
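
Two things change in tokenize(): the caller now decides, via `special`, whether strings like "<|im_end|>" are parsed as single control tokens, and the leading space that the bundled llama.cpp fork no longer inserts is added here for ordinary text at the start of the context. A rough illustration of the intent (comment form only, since tokenize() is a protected member; token behavior as described in the code above):

    // tokenize(ctx, "<|im_start|>user\n", /*special=*/true)   -> "<|im_start|>" becomes one control token
    // tokenize(ctx, "<|im_start|>user\n", /*special=*/false)  -> the same characters stay literal text,
    //                                                            so user input cannot inject control tokens
    // tokenize(ctx, "Hello", /*special=*/false) at n_past == 0 -> actually tokenizes " Hello" (leading space)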
@@ -349,55 +427,10 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
return d_ptr->end_tokens;
}
std::string get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != (GGUF_TYPE_STRING)) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
return gguf_get_val_str(ctx_gguf, kid);
}
static gguf_context *load_gguf(const char *fname, std::string &arch) {
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ nullptr,
};
gguf_context *ctx = gguf_init_from_file(fname, params);
if (!ctx) {
std::cerr << __func__ << ": gguf_init_from_file failed\n";
return nullptr;
}
int gguf_ver = gguf_get_version(ctx);
if (gguf_ver > GGUF_VER_MAX) {
std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
gguf_free(ctx);
return nullptr;
}
arch = get_arch_name(ctx);
return ctx;
}
static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
std::string arch;
auto * ctx = load_gguf(modelPath.c_str(), arch);
int32_t value = -1;
if (ctx) {
auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
}
}
gguf_free(ctx);
return value;
bool LLamaModel::shouldAddBOS() const
{
int add_bos = llama_add_bos_token(d_ptr->model);
return add_bos != -1 ? bool(add_bos) : llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_SPM;
}
int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -513,8 +546,8 @@ DLL_EXPORT const char *get_build_variant() {
}
DLL_EXPORT bool magic_match(const char *fname) {
std::string arch;
auto * ctx = load_gguf(fname, arch);
auto * ctx = load_gguf(fname);
auto arch = get_arch_name(ctx);
bool valid = true;

gpt4all-backend/llamamodel_impl.h

@@ -19,6 +19,7 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelBlacklisted(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
@@ -27,7 +28,7 @@ public:
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired) const override;
bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const override;
bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const override;
bool initializeGPUDevice(int device, std::string *unavail_reason) const override;
bool hasGPUDevice() override;
bool usingGPUDevice() override;
@@ -36,12 +37,13 @@ private:
std::unique_ptr<LLamaPrivate> d_ptr;
protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string tokenToString(Token) const override;
Token sampleToken(PromptContext& ctx) const override;
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override;

gpt4all-backend/llmodel.h

@@ -29,23 +29,23 @@ public:
class Implementation {
public:
Implementation(Dlhandle&&);
Implementation(const Implementation&) = delete;
Implementation(Implementation&&);
Implementation(Dlhandle &&);
Implementation(const Implementation &) = delete;
Implementation(Implementation &&);
~Implementation();
std::string_view modelType() const { return m_modelType; }
std::string_view buildVariant() const { return m_buildVariant; }
static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static bool isImplementation(const Dlhandle &dl);
static const std::vector<Implementation> &implementationList();
static const Implementation *implementation(const char *fname, const std::string &buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices();
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath();
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();
private:
static LLModel *constructDefaultLlama();
@@ -82,26 +82,30 @@ public:
virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual bool isModelBlacklisted(const std::string &modelPath) { (void)modelPath; return false; };
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; }
virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; }
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx);
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
virtual std::vector<float> embedding(const std::string &text);
virtual void setThreadCount(int32_t /*n_threads*/) {}
virtual void setThreadCount(int32_t n_threads) { (void)n_threads; }
virtual int32_t threadCount() const { return 1; }
const Implementation& implementation() const {
const Implementation &implementation() const {
return *m_implementation;
}
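
With the new virtual prompt() above, callers pass the raw user input and the chat template separately and let the model splice them together. A hedged sketch of a caller, assuming `model` is an LLModel* obtained from LLModel::Implementation::construct() and already loaded; the template string and callbacks are illustrative:

    LLModel::PromptContext ctx;
    model->prompt(
        "What does GGUF stand for?",                 // user input, kept free of special-token parsing
        "### Human:\n%1\n### Assistant:\n%2\n",      // template; %1 = input, %2 = reply position
        [](int32_t) { return true; },                                               // prompt progress
        [](int32_t, const std::string &piece) { std::cout << piece; return true; }, // stream response
        [](bool) { return true; },                                                  // recalculate notice
        ctx);                                        // special=false, fakeReply=nullptr by default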
@@ -110,7 +114,7 @@ public:
return {};
}
virtual bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const {
virtual bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const {
(void)memoryRequired;
(void)name;
return false;
@@ -132,12 +136,13 @@ public:
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
virtual std::string tokenToString(Token) const = 0;
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token>& endTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
virtual int32_t maxContextLength(std::string const &modelPath) const
{
@@ -166,6 +171,15 @@ protected:
return true;
}
void decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx);
private:
friend class LLMImplementation;
};

gpt4all-backend/llmodel_c.cpp

@@ -1,8 +1,9 @@
#include "llmodel_c.h"
#include "llmodel.h"
#include <cstring>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <utility>
struct LLModelWrapper {
@@ -56,7 +57,14 @@ size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_c
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path, n_ctx, ngl);
std::string modelPath(model_path);
if (wrapper->llModel->isModelBlacklisted(modelPath)) {
size_t slash = modelPath.find_last_of("/\\");
auto basename = slash == std::string::npos ? modelPath : modelPath.substr(slash + 1);
std::cerr << "warning: model '" << basename << "' is out-of-date, please check for an updated version\n";
}
return wrapper->llModel->loadModel(modelPath, n_ctx, ngl);
}
bool llmodel_isModelLoaded(llmodel_model model)
@@ -100,10 +108,12 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
}
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx)
llmodel_prompt_context *ctx,
bool special)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
@@ -131,7 +141,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext, special);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies

gpt4all-backend/llmodel_c.h

@@ -163,16 +163,20 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx);
llmodel_prompt_context *ctx,
bool special);
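
On the C side the same information now crosses the FFI boundary explicitly. A hedged usage sketch, assuming `model`, the three callbacks, and `prompt_ctx` are set up the same way as before this change:

    llmodel_prompt(model,
                   "Write a haiku about winter.",           // user input
                   "### Human:\n%1\n### Assistant:\n%2\n",  // chat-style template
                   prompt_callback,
                   response_callback,
                   recalculate_callback,
                   &prompt_ctx,
                   /*special=*/false);                       // don't parse special tokens in the user input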
/**
* Generate an embedding using the model.

gpt4all-backend/llmodel_shared.cpp

@@ -2,11 +2,20 @@
#include <cassert>
#include <iostream>
#include <regex>
#include <unordered_set>
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
size_t i = 0;
promptCtx.n_past = 0;
int n_keep = shouldAddBOS();
const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;
// Erase the first percentage of context from the tokens
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
size_t i = n_keep;
promptCtx.n_past = n_keep;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
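
A worked example of the new erasure arithmetic (numbers illustrative): with n_ctx = 2048, contextErase = 0.75, and shouldAddBOS() returning true, n_keep = 1 and n_discard = (2048 - 1) * 0.75 = 1535, so tokens[1..1535] are dropped, the BOS token is never discarded, and re-evaluation restarts at i = n_past = n_keep = 1, re-processing only the surviving tail of the context.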
@@ -26,11 +35,36 @@ stop_generating:
recalculate(false);
}
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
placeholders.clear();
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
if (placeholders.size() > 2) {
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
}
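
Concretely, the placeholder regex matches %1 or %2 only when not followed by another digit, so "%10" is treated as plain text. A few illustrative inputs and how this parser treats them:

    // accepted: "%1"                                    -> user prompt only, no reply placeholder
    // accepted: "### Human:\n%1\n### Assistant:\n%2\n"  -> user prompt and assistant reply
    // accepted: "no placeholders"                       -> unusual but well-defined (handled in prompt() below)
    // rejected: "%2 comes before %1"                    -> first placeholder must be %1
    // rejected: "%1 %1 %2"                              -> at most two placeholders are allowed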
void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx)
PromptContext &promptCtx,
bool special,
std::string *fakeReply)
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -38,15 +72,86 @@ void LLModel::prompt(const std::string &prompt,
}
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
responseCallback(-1, errorMessage);
std::cerr << implementation().modelType() << errorMessage;
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
return;
}
// tokenize the prompt
std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptCtx, promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty()) {
embd_inp = tokenize(promptCtx, userPrefix, true);
promptCtx.n_past += embd_inp.size();
}
// user input (shouldn't have special token processing)
auto tokens = tokenize(promptCtx, prompt, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(promptCtx, userToAsst, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
}
}
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
// decode the user prompt
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
generateResponse(responseCallback, recalculateCallback, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
}
// decode the rest of the prompt template
if (placeholders.size() >= 2) {
// template: end of assistant prompt
size_t start = placeholders[1].position() + placeholders[1].length();
auto asstSuffix = promptTemplate.substr(start);
if (!asstSuffix.empty()) {
embd_inp = tokenize(promptCtx, asstSuffix, true);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
}
}
}
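
To make the segmentation above concrete, take an illustrative ChatML-style template and the user input "Hi":

    // promptTemplate = "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n"
    // userPrefix  = "<|im_start|>user\n"                 -> tokenized with special=true
    // prompt      = "Hi"                                 -> tokenized with special as passed in (false by default)
    // userToAsst  = "<|im_end|>\n<|im_start|>assistant\n" -> tokenized with special=true
    //               (the reply is then generated, or taken from *fakeReply)
    // asstSuffix  = "<|im_end|>\n"                        -> tokenized with special=true after the reply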
void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx,
std::vector<Token> embd_inp) {
// save the context size
promptCtx.n_ctx = contextLength();
@@ -69,11 +174,6 @@ void LLModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
@@ -94,7 +194,11 @@ void LLModel::prompt(const std::string &prompt,
}
i = batch_end;
}
}
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
std::string cachedResponse;
std::vector<Token> cachedTokens;
std::unordered_set<std::string> reversePrompts
@@ -108,11 +212,6 @@ void LLModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
@@ -165,8 +264,9 @@ void LLModel::prompt(const std::string &prompt,
}
}
std::vector<float> LLModel::embedding(const std::string &/*text*/)
std::vector<float> LLModel::embedding(const std::string &text)
{
(void)text;
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
std::cerr << implementation().modelType() << errorMessage;