fix chat-style prompt templates (#1970)
Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -814,8 +814,10 @@ std::vector<float> Bert::embedding(const std::string &text)
    return finalEmbeddings;
}

std::vector<LLModel::Token> Bert::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> Bert::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
    (void)ctx;
    (void)special;
    return ::bert_tokenize(d_ptr->ctx, str.c_str());
}
@@ -33,12 +33,13 @@ private:
    std::unique_ptr<BertPrivate> d_ptr;

protected:
    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
    Token sampleToken(PromptContext &ctx) const override;
    std::string tokenToString(Token) const override;
    std::string tokenToString(Token id) const override;
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
    int32_t contextLength() const override;
    const std::vector<Token>& endTokens() const override;
    const std::vector<Token> &endTokens() const override;
    bool shouldAddBOS() const override { return true; }
};

#endif // BERT_H
@@ -737,8 +737,10 @@ size_t GPTJ::restoreState(const uint8_t *src)
    return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
}

std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
    (void)ctx;
    (void)special;
    return ::gpt_tokenize(d_ptr->vocab, str);
}
@@ -30,12 +30,13 @@ private:
    GPTJPrivate *d_ptr;

protected:
    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
    Token sampleToken(PromptContext &ctx) const override;
    std::string tokenToString(Token) const override;
    std::string tokenToString(Token id) const override;
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
    int32_t contextLength() const override;
    const std::vector<Token>& endTokens() const override;
    const std::vector<Token> &endTokens() const override;
    bool shouldAddBOS() const override { return false; }
};

#endif // GPTJ_H
Submodule gpt4all-backend/llama.cpp-mainline updated: 7d4ced8505...b61ee89fca
@@ -6,38 +6,29 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <iomanip>
#include <iostream>
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h>
#else
#include <unistd.h>
#endif
#include <map>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include <llama.h>
#include <ggml.h>

#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#include <ggml-kompute.h>
#endif

using namespace std::string_literals;

// Maximum supported GGUF version
static constexpr int GGUF_VER_MAX = 3;

namespace {
const char *modelType_ = "LLaMA";
}
static const char * const modelType_ = "LLaMA";

static bool llama_verbose() {
    const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");
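The reshuffled includes also pull in <sstream>, <stdexcept>, and the std::string_literals namespace, which the new GGUF inspection code further down relies on to build exception and comparison strings with the "s" suffix (for example "<dummy32000>"s). A minimal standalone illustration of that idiom, not taken from the diff; the key and file names are made up:

#include <iostream>
#include <stdexcept>
#include <string>

using namespace std::string_literals;

int main() {
    const char *key = "general.name";  // hypothetical missing key
    try {
        // The "s" suffix turns the literal into a std::string, so const char*
        // operands can be concatenated with + directly.
        throw std::logic_error(key + " not found in "s + "model.gguf");
    } catch (const std::logic_error &e) {
        std::cout << e.what() << "\n";  // prints: general.name not found in model.gguf
    }
}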
@@ -96,6 +87,56 @@ static int llama_sample_top_p_top_k(
    return llama_sample_token(ctx, &candidates_p);
}

std::string get_arch_name(gguf_context *ctx_gguf) {
    std::string arch_name;
    const int kid = gguf_find_key(ctx_gguf, "general.architecture");
    enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
    if (ktype != (GGUF_TYPE_STRING)) {
        throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
    }
    return gguf_get_val_str(ctx_gguf, kid);
}

static gguf_context *load_gguf(const char *fname) {
    struct gguf_init_params params = {
        /*.no_alloc = */ true,
        /*.ctx = */ nullptr,
    };
    gguf_context *ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        std::cerr << __func__ << ": gguf_init_from_file failed\n";
        return nullptr;
    }

    int gguf_ver = gguf_get_version(ctx);
    if (gguf_ver > GGUF_VER_MAX) {
        std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
        gguf_free(ctx);
        return nullptr;
    }

    return ctx;
}

static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
    auto * ctx = load_gguf(modelPath.c_str());
    auto arch = get_arch_name(ctx);

    int32_t value = -1;
    if (ctx) {
        auto key = arch + "." + archKey;
        int keyidx = gguf_find_key(ctx, key.c_str());
        if (keyidx != -1) {
            value = gguf_get_val_u32(ctx, keyidx);
        } else {
            std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
        }
    }

    gguf_free(ctx);
    return value;
}

struct LLamaPrivate {
    const std::string modelPath;
    bool modelLoaded;
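These file-scope helpers now sit above their first use: load_gguf() opens only the GGUF metadata (no tensor allocation) and enforces GGUF_VER_MAX, get_arch_name() reads general.architecture, and get_arch_key_u32() combines the two to fetch an architecture-scoped integer. A sketch of a small debugging helper that could sit in the same translation unit; the key names (context_length, block_count) are conventional GGUF keys assumed here, not something this hunk shows:

// Sketch only: assumes it lives next to the helpers above in llamamodel.cpp.
static void dump_gguf_limits(const std::string &modelPath) {
    // For a model whose general.architecture is "llama", these resolve to
    // "llama.context_length" and "llama.block_count".
    int32_t n_ctx_train = get_arch_key_u32(modelPath, "context_length");
    int32_t n_layer     = get_arch_key_u32(modelPath, "block_count");
    if (n_ctx_train < 0 || n_layer < 0) {
        std::cerr << "could not read GGUF metadata for " << modelPath << "\n";
        return;
    }
    std::cerr << "trained context: " << n_ctx_train << ", layers: " << n_layer << "\n";
}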
@@ -148,6 +189,42 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
    return filesize + est_kvcache_size;
}

bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
    auto * ctx = load_gguf(modelPath.c_str());
    if (!ctx) {
        std::cerr << __func__ << ": failed to load " << modelPath << "\n";
        return false;
    }

    auto get_key = [ctx, &modelPath](const char *name) {
        int keyidx = gguf_find_key(ctx, name);
        if (keyidx == -1) {
            throw std::logic_error(name + " not found in "s + modelPath);
        }
        return keyidx;
    };

    bool res = false;
    try {
        std::string name(gguf_get_val_str(ctx, get_key("general.name")));
        int token_idx = get_key("tokenizer.ggml.tokens");
        int n_vocab = gguf_get_arr_n(ctx, token_idx);

        // check for known bad models
        if (name == "open-orca_mistral-7b-openorca"
            && n_vocab == 32002
            && gguf_get_arr_str(ctx, token_idx, 32000) == "<dummy32000>"s // should be <|im_end|>
        ) {
            res = true;
        }
    } catch (const std::logic_error &e) {
        std::cerr << __func__ << ": " << e.what() << "\n";
    }

    gguf_free(ctx);
    return res;
}

bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
    d_ptr->modelLoaded = false;
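isModelBlacklisted() detects the original Mistral OpenOrca GGUF, whose vocabulary still carries the <dummy32000> placeholder where <|im_end|> belongs, so ChatML-style templates cannot round-trip through it. The intent is to warn and continue rather than refuse to load, which is what the C wrapper later in this diff does. A hedged sketch of that calling pattern; the path and parameters are illustrative:

#include "llmodel.h"

#include <iostream>
#include <string>

// Sketch: warn-but-continue handling around the new blacklist check.
// `model` is any LLModel whose backend implements isModelBlacklisted().
void loadWithWarning(LLModel &model, const std::string &path) {
    if (model.isModelBlacklisted(path)) {
        std::cerr << "warning: model '" << path
                  << "' is out-of-date, please check for an updated version\n";
    }
    if (!model.loadModel(path, /*n_ctx*/ 2048, /*ngl*/ 100)) {
        std::cerr << "failed to load " << path << "\n";
    }
}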
@@ -290,12 +367,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
    return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}

std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->model));
    std::vector<LLModel::Token> fres(str.size()+4);
    // TODO(cebtenzzre): we may want to use special=true here to process special tokens
    auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, false);
    const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
    const bool useBOS = wantBOS && shouldAddBOS();
    auto strCat = wantBOS && !special ? " " + str : str; // insert leading space ourselves, llama.cpp fork doesn't anymore
    std::vector<LLModel::Token> fres(strCat.size()+4);
    auto fres_len = llama_tokenize(d_ptr->model, strCat.c_str(), strCat.length(), fres.data(), fres.size(), useBOS, special);
    fres.resize(fres_len);
    return fres;
}
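The rewritten tokenize() only requests a BOS token at the very start of a fresh context and only when the model wants one (shouldAddBOS()), and it forwards `special` so template text such as <|im_start|> is mapped to real control tokens instead of being split as plain text. The manual leading space compensates for the updated llama.cpp fork no longer inserting one. A standalone sketch of the same decision logic, purely illustrative and with no llama.cpp calls:

#include <iostream>
#include <string>

// Illustrative only: mirrors the wantBOS/useBOS/leading-space decisions above.
// `modelWantsBOS` stands in for shouldAddBOS().
struct TokenizePlan {
    bool addBOS;
    std::string text;
};

TokenizePlan planTokenize(bool freshContext, bool modelWantsBOS, bool special, std::string str) {
    const bool wantBOS = freshContext;              // ctx.n_past == 0 && ctx.tokens.empty()
    const bool useBOS  = wantBOS && modelWantsBOS;  // shouldAddBOS()
    if (wantBOS && !special)
        str = " " + str;                            // leading space, as in the diff
    return {useBOS, str};
}

int main() {
    auto user = planTokenize(true, true, /*special*/ false, "Hello!");
    auto tmpl = planTokenize(true, true, /*special*/ true,  "<|im_start|>user");
    std::cout << user.addBOS << " '" << user.text << "'\n";  // 1 ' Hello!'
    std::cout << tmpl.addBOS << " '" << tmpl.text << "'\n";  // 1 '<|im_start|>user'
}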
@@ -349,55 +427,10 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
    return d_ptr->end_tokens;
}

std::string get_arch_name(gguf_context *ctx_gguf) {
    std::string arch_name;
    const int kid = gguf_find_key(ctx_gguf, "general.architecture");
    enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
    if (ktype != (GGUF_TYPE_STRING)) {
        throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
    }
    return gguf_get_val_str(ctx_gguf, kid);
}

static gguf_context *load_gguf(const char *fname, std::string &arch) {
    struct gguf_init_params params = {
        /*.no_alloc = */ true,
        /*.ctx = */ nullptr,
    };
    gguf_context *ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        std::cerr << __func__ << ": gguf_init_from_file failed\n";
        return nullptr;
    }

    int gguf_ver = gguf_get_version(ctx);
    if (gguf_ver > GGUF_VER_MAX) {
        std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
        gguf_free(ctx);
        return nullptr;
    }

    arch = get_arch_name(ctx);
    return ctx;
}

static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
    std::string arch;
    auto * ctx = load_gguf(modelPath.c_str(), arch);

    int32_t value = -1;
    if (ctx) {
        auto key = arch + "." + archKey;
        int keyidx = gguf_find_key(ctx, key.c_str());
        if (keyidx != -1) {
            value = gguf_get_val_u32(ctx, keyidx);
        } else {
            std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
        }
    }

    gguf_free(ctx);
    return value;
bool LLamaModel::shouldAddBOS() const
{
    int add_bos = llama_add_bos_token(d_ptr->model);
    return add_bos != -1 ? bool(add_bos) : llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_SPM;
}

int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -513,8 +546,8 @@ DLL_EXPORT const char *get_build_variant() {
}

DLL_EXPORT bool magic_match(const char *fname) {
    std::string arch;
    auto * ctx = load_gguf(fname, arch);
    auto * ctx = load_gguf(fname);
    auto arch = get_arch_name(ctx);

    bool valid = true;
@@ -19,6 +19,7 @@ public:
    bool supportsEmbedding() const override { return false; }
    bool supportsCompletion() const override { return true; }
    bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
    bool isModelBlacklisted(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
    size_t stateSize() const override;
@@ -27,7 +28,7 @@ public:
    void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() const override;
    std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired) const override;
    bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const override;
    bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const override;
    bool initializeGPUDevice(int device, std::string *unavail_reason) const override;
    bool hasGPUDevice() override;
    bool usingGPUDevice() override;
@@ -36,12 +37,13 @@ private:
    std::unique_ptr<LLamaPrivate> d_ptr;

protected:
    std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
    std::string tokenToString(Token) const override;
    Token sampleToken(PromptContext& ctx) const override;
    bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
    std::string tokenToString(Token id) const override;
    Token sampleToken(PromptContext &ctx) const override;
    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
    int32_t contextLength() const override;
    const std::vector<Token>& endTokens() const override;
    const std::vector<Token> &endTokens() const override;
    bool shouldAddBOS() const override;

    int32_t maxContextLength(std::string const &modelPath) const override;
    int32_t layerCount(std::string const &modelPath) const override;
@@ -29,23 +29,23 @@ public:

    class Implementation {
    public:
        Implementation(Dlhandle&&);
        Implementation(const Implementation&) = delete;
        Implementation(Implementation&&);
        Implementation(Dlhandle &&);
        Implementation(const Implementation &) = delete;
        Implementation(Implementation &&);
        ~Implementation();

        std::string_view modelType() const { return m_modelType; }
        std::string_view buildVariant() const { return m_buildVariant; }

        static bool isImplementation(const Dlhandle&);
        static const std::vector<Implementation>& implementationList();
        static const Implementation *implementation(const char *fname, const std::string& buildVariant);
        static bool isImplementation(const Dlhandle &dl);
        static const std::vector<Implementation> &implementationList();
        static const Implementation *implementation(const char *fname, const std::string &buildVariant);
        static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
        static std::vector<GPUDevice> availableGPUDevices();
        static int32_t maxContextLength(const std::string &modelPath);
        static int32_t layerCount(const std::string &modelPath);
        static void setImplementationsSearchPath(const std::string& path);
        static const std::string& implementationsSearchPath();
        static void setImplementationsSearchPath(const std::string &path);
        static const std::string &implementationsSearchPath();

    private:
        static LLModel *constructDefaultLlama();
@@ -82,26 +82,30 @@ public:
    virtual bool supportsEmbedding() const = 0;
    virtual bool supportsCompletion() const = 0;
    virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0;
    virtual bool isModelBlacklisted(const std::string &modelPath) { (void)modelPath; return false; };
    virtual bool isModelLoaded() const = 0;
    virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
    virtual size_t stateSize() const { return 0; }
    virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
    virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
    virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; }
    virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; }

    // This method requires the model to return true from supportsCompletion otherwise it will throw
    // an error
    virtual void prompt(const std::string &prompt,
                        const std::string &promptTemplate,
                        std::function<bool(int32_t)> promptCallback,
                        std::function<bool(int32_t, const std::string&)> responseCallback,
                        std::function<bool(bool)> recalculateCallback,
                        PromptContext &ctx);
                        PromptContext &ctx,
                        bool special = false,
                        std::string *fakeReply = nullptr);

    virtual std::vector<float> embedding(const std::string &text);

    virtual void setThreadCount(int32_t /*n_threads*/) {}
    virtual void setThreadCount(int32_t n_threads) { (void)n_threads; }
    virtual int32_t threadCount() const { return 1; }

    const Implementation& implementation() const {
    const Implementation &implementation() const {
        return *m_implementation;
    }
@@ -110,7 +114,7 @@ public:
        return {};
    }

    virtual bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const {
    virtual bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const {
        (void)memoryRequired;
        (void)name;
        return false;
@@ -132,12 +136,13 @@ public:
protected:
    // These are pure virtual because subclasses need to implement as the default implementation of
    // 'prompt' above calls these functions
    virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
    virtual std::string tokenToString(Token) const = 0;
    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
    virtual std::string tokenToString(Token id) const = 0;
    virtual Token sampleToken(PromptContext &ctx) const = 0;
    virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
    virtual int32_t contextLength() const = 0;
    virtual const std::vector<Token>& endTokens() const = 0;
    virtual const std::vector<Token> &endTokens() const = 0;
    virtual bool shouldAddBOS() const = 0;

    virtual int32_t maxContextLength(std::string const &modelPath) const
    {
@@ -166,6 +171,15 @@ protected:
        return true;
    }

    void decodePrompt(std::function<bool(int32_t)> promptCallback,
                      std::function<bool(int32_t, const std::string&)> responseCallback,
                      std::function<bool(bool)> recalculateCallback,
                      PromptContext &promptCtx,
                      std::vector<Token> embd_inp);
    void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
                          std::function<bool(bool)> recalculateCallback,
                          PromptContext &promptCtx);

private:
    friend class LLMImplementation;
};
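With decodePrompt() and generateResponse() split out of prompt(), the public entry point can run the three phases (user turn, assistant reply, template suffix) independently, and the new fakeReply parameter lets a caller replay a stored assistant message through decodePrompt() instead of sampling a fresh one. A hedged usage sketch of the new public signature; the callbacks, template, and reply text are illustrative:

#include "llmodel.h"

#include <iostream>
#include <string>

// Sketch: replaying a saved conversation turn through the new prompt() API.
// `model` is any loaded LLModel; the template below is ChatML-style.
void replayTurn(LLModel &model, LLModel::PromptContext &ctx) {
    const std::string promptTemplate =
        "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n";
    std::string savedReply = "Hello! How can I help you today?";  // previously generated text

    auto onPrompt   = [](int32_t) { return true; };
    auto onResponse = [](int32_t, const std::string &piece) { std::cout << piece; return true; };
    auto onRecalc   = [](bool) { return true; };

    // special=false: the user text is not scanned for control tokens; the
    // template itself is always tokenized with special handling enabled.
    model.prompt("Hi there", promptTemplate, onPrompt, onResponse, onRecalc,
                 ctx, /*special*/ false, &savedReply);
}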
@@ -1,8 +1,9 @@
#include "llmodel_c.h"
#include "llmodel.h"

#include <cstring>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <utility>

struct LLModelWrapper {
@@ -56,7 +57,14 @@ size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_c
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->loadModel(model_path, n_ctx, ngl);

    std::string modelPath(model_path);
    if (wrapper->llModel->isModelBlacklisted(modelPath)) {
        size_t slash = modelPath.find_last_of("/\\");
        auto basename = slash == std::string::npos ? modelPath : modelPath.substr(slash + 1);
        std::cerr << "warning: model '" << basename << "' is out-of-date, please check for an updated version\n";
    }
    return wrapper->llModel->loadModel(modelPath, n_ctx, ngl);
}

bool llmodel_isModelLoaded(llmodel_model model)
@@ -100,10 +108,12 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
}

void llmodel_prompt(llmodel_model model, const char *prompt,
                    const char *prompt_template,
                    llmodel_prompt_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx)
                    llmodel_prompt_context *ctx,
                    bool special)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
@@ -131,7 +141,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
    wrapper->promptContext.contextErase = ctx->context_erase;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
    wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext, special);

    // Update the C context by giving access to the wrappers raw pointers to std::vector data
    // which involves no copies
@@ -163,16 +163,20 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 * Generate a response using the model.
 * @param model A pointer to the llmodel_model instance.
 * @param prompt A string representing the input prompt.
 * @param prompt_template A string representing the input prompt template.
 * @param prompt_callback A callback function for handling the processing of prompt.
 * @param response_callback A callback function for handling the generated response.
 * @param recalculate_callback A callback function for handling recalculation requests.
 * @param special True if special tokens in the prompt should be processed, false otherwise.
 * @param ctx A pointer to the llmodel_prompt_context structure.
 */
void llmodel_prompt(llmodel_model model, const char *prompt,
                    const char *prompt_template,
                    llmodel_prompt_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx);
                    llmodel_prompt_context *ctx,
                    bool special);

/**
 * Generate an embedding using the model.
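From the C API side, callers now pass the chat template separately from the user text and state whether that text may contain special tokens. A hedged sketch of a call site, assuming the callback typedefs keep their existing shapes (token id for the prompt callback, token id plus UTF-8 text for the response callback, a bool for recalculation); the context setup and model creation are omitted and the question text is made up:

// Sketch: calling the updated C API. Assumes `model` was created and loaded
// elsewhere and `ctx` is an already-configured llmodel_prompt_context.
#include "llmodel_c.h"
#include <stdbool.h>
#include <stdio.h>

static bool on_prompt(int32_t token_id) { (void)token_id; return true; }
static bool on_response(int32_t token_id, const char *response) {
    (void)token_id;
    fputs(response, stdout);
    return true;
}
static bool on_recalculate(bool is_recalculating) { (void)is_recalculating; return true; }

void ask(llmodel_model model, llmodel_prompt_context *ctx) {
    // The template carries the control tokens; the user text stays literal
    // because `special` is false.
    llmodel_prompt(model, "What is a GGUF file?",
                   "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
                   on_prompt, on_response, on_recalculate,
                   ctx, /* special */ false);
}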
@@ -2,11 +2,20 @@

#include <cassert>
#include <iostream>
#include <regex>
#include <unordered_set>

// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
    size_t i = 0;
    promptCtx.n_past = 0;
    int n_keep = shouldAddBOS();
    const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;

    // Erase the first percentage of context from the tokens
    std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);

    size_t i = n_keep;
    promptCtx.n_past = n_keep;
    while (i < promptCtx.tokens.size()) {
        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
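Context recycling now keeps the BOS token alive when the model uses one (n_keep is 0 or 1, taken straight from shouldAddBOS()) and discards a contextErase-sized slice measured from just past that point. A small standalone check of the arithmetic; the numbers are only examples, not values from the diff:

#include <cstdint>
#include <iostream>

int main() {
    // Example: a 2048-token window, half erased on overflow, and a model that
    // wants a BOS token kept at position 0.
    const int32_t n_ctx = 2048;
    const float contextErase = 0.5f;
    const int n_keep = 1;  // bool(shouldAddBOS()) converted to int

    const int32_t n_discard = (n_ctx - n_keep) * contextErase;
    std::cout << "keep " << n_keep << " token(s), discard " << n_discard
              << ", re-decode the remaining tail\n";  // discard = 1023
}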
@@ -26,11 +35,36 @@ stop_generating:
    recalculate(false);
}

static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");

    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
    placeholders.clear();
    placeholders.insert(placeholders.end(), it, std::sregex_iterator());

    if (placeholders.size() > 2) {
        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
        return false;
    }
    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
        return false;
    }
    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
        return false;
    }
    return true;
}

void LLModel::prompt(const std::string &prompt,
                     const std::string &promptTemplate,
                     std::function<bool(int32_t)> promptCallback,
                     std::function<bool(int32_t, const std::string&)> responseCallback,
                     std::function<bool(bool)> recalculateCallback,
                     PromptContext &promptCtx)
                     PromptContext &promptCtx,
                     bool special,
                     std::string *fakeReply)
{
    if (!isModelLoaded()) {
        std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
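parsePromptTemplate() accepts at most two placeholders, in order %1 (user input) then %2 (assistant reply), and the (?![0-9]) lookahead keeps literal text such as "%10" from counting as a placeholder. A self-contained check using the same regex; the templates are examples only:

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    // Same pattern as parsePromptTemplate() above: %1 or %2 not followed by a digit.
    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");

    const std::vector<std::string> templates = {
        "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",  // accepted
        "### Human:\n%1\n### Assistant:\n",                                       // accepted, %2 is optional
        "%2 then %1",                                                             // rejected: wrong order
    };
    for (const auto &tmpl : templates) {
        std::vector<std::smatch> ph{std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex),
                                    std::sregex_iterator()};
        bool ok = ph.size() <= 2
            && (ph.size() < 1 || ph[0].str() == "%1")
            && (ph.size() < 2 || ph[1].str() == "%2");
        std::cout << (ok ? "accepted: " : "rejected: ") << tmpl << "\n";
    }
}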
@@ -38,15 +72,86 @@ void LLModel::prompt(const std::string &prompt,
    }

    if (!supportsCompletion()) {
        std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
        std::string errorMessage = "ERROR: this model does not support text completion or chat!";
        responseCallback(-1, errorMessage);
        std::cerr << implementation().modelType() << errorMessage;
        std::cerr << implementation().modelType() << " " << errorMessage << "\n";
        return;
    }

    // tokenize the prompt
    std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
    // parse the prompt template
    std::vector<std::smatch> placeholders;
    {
        std::string err;
        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
            responseCallback(-1, err);
            std::cerr << err << "\n";
            return;
        }
    }

    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize

    // tokenize the user prompt
    std::vector<Token> embd_inp;
    if (placeholders.empty()) {
        // this is unusual, but well-defined
        std::cerr << __func__ << ": prompt template has no placeholder\n";
        embd_inp = tokenize(promptCtx, promptTemplate, true);
    } else {
        // template: beginning of user prompt
        const auto &phUser = placeholders[0];
        std::string userPrefix(phUser.prefix());
        if (!userPrefix.empty()) {
            embd_inp = tokenize(promptCtx, userPrefix, true);
            promptCtx.n_past += embd_inp.size();
        }

        // user input (shouldn't have special token processing)
        auto tokens = tokenize(promptCtx, prompt, special);
        embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
        promptCtx.n_past += tokens.size();

        // template: end of user prompt + start of assistant prompt
        size_t start = phUser.position() + phUser.length();
        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
        auto userToAsst = promptTemplate.substr(start, end - start);
        if (!userToAsst.empty()) {
            tokens = tokenize(promptCtx, userToAsst, true);
            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
            promptCtx.n_past += tokens.size();
        }
    }

    promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it

    // decode the user prompt
    decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);

    // decode the assistant's reply, either generated or spoofed
    if (fakeReply == nullptr) {
        generateResponse(responseCallback, recalculateCallback, promptCtx);
    } else {
        embd_inp = tokenize(promptCtx, *fakeReply, false);
        decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
    }

    // decode the rest of the prompt template
    if (placeholders.size() >= 2) {
        // template: end of assistant prompt
        size_t start = placeholders[1].position() + placeholders[1].length();
        auto asstSuffix = promptTemplate.substr(start);
        if (!asstSuffix.empty()) {
            embd_inp = tokenize(promptCtx, asstSuffix, true);
            decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
        }
    }
}

void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
                           std::function<bool(int32_t, const std::string&)> responseCallback,
                           std::function<bool(bool)> recalculateCallback,
                           PromptContext &promptCtx,
                           std::vector<Token> embd_inp) {
    // save the context size
    promptCtx.n_ctx = contextLength();
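This is the core of the fix: the template is cut at the %1 and %2 matches into three spans (the prefix before the user text, the stretch between user text and assistant reply, and the suffix after the reply), and only those template spans are tokenized with special-token processing enabled. A standalone illustration of the same splitting, using a ChatML-style template as the example:

#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    const std::string tmpl = "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n";
    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
    std::vector<std::smatch> ph{std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex),
                                std::sregex_iterator()};

    // Same slicing as LLModel::prompt() above.
    std::string userPrefix(ph[0].prefix());                      // "<|im_start|>user\n"
    size_t start = ph[0].position() + ph[0].length();
    size_t end = ph.size() >= 2 ? ph[1].position() : tmpl.length();
    std::string userToAsst = tmpl.substr(start, end - start);    // "<|im_end|>\n<|im_start|>assistant\n"
    std::string asstSuffix = ph.size() >= 2
        ? tmpl.substr(ph[1].position() + ph[1].length()) : "";   // "<|im_end|>\n"

    std::cout << "prefix:     " << userPrefix
              << "user->asst: " << userToAsst
              << "suffix:     " << asstSuffix;
}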
@@ -69,11 +174,6 @@ void LLModel::prompt(const std::string &prompt,

        // Check if the context has run out...
        if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
            // Erase the first percentage of context from the tokens...
            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
            promptCtx.n_past = promptCtx.tokens.size();
            recalculateContext(promptCtx, recalculateCallback);
            assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
        }
@@ -94,7 +194,11 @@ void LLModel::prompt(const std::string &prompt,
        }
        i = batch_end;
    }
}

void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
                               std::function<bool(bool)> recalculateCallback,
                               PromptContext &promptCtx) {
    std::string cachedResponse;
    std::vector<Token> cachedTokens;
    std::unordered_set<std::string> reversePrompts
@@ -108,11 +212,6 @@ void LLModel::prompt(const std::string &prompt,

        // Check if the context has run out...
        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
            const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
            // Erase the first percentage of context from the tokens...
            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
            promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
            promptCtx.n_past = promptCtx.tokens.size();
            recalculateContext(promptCtx, recalculateCallback);
            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
        }
@@ -165,8 +264,9 @@ void LLModel::prompt(const std::string &prompt,
    }
}

std::vector<float> LLModel::embedding(const std::string &/*text*/)
std::vector<float> LLModel::embedding(const std::string &text)
{
    (void)text;
    if (!supportsCompletion()) {
        std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
        std::cerr << implementation().modelType() << errorMessage;