Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-04-29 12:14:35 +00:00)

fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>

Parent: b8f5c74f40
Commit: 4fc4d94be4
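For orientation before the diff, here is a rough, string-level illustration (not part of the commit) of what a chat-style prompt template means after this change: `%1` marks where the user prompt goes and an optional `%2` marks the assistant's reply. The real implementation (see `parsePromptTemplate` and `LLModel::prompt` below) tokenizes each template segment separately so special tokens such as `<|im_start|>` are honored in the template but not in user text; the template string here is the ChatML form used in models.json.

```python
# Illustration only: string-level rendering of a chat-style template.
# "%1" = user prompt, optional "%2" = assistant reply (see parsePromptTemplate below).
template = "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n"

turn = template.replace("%1", "Why is the sky blue?").replace("%2", "Rayleigh scattering.")
print(turn)
```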
@@ -814,8 +814,10 @@ std::vector<float> Bert::embedding(const std::string &text)
return finalEmbeddings;
}

std::vector<LLModel::Token> Bert::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> Bert::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
(void)ctx;
(void)special;
return ::bert_tokenize(d_ptr->ctx, str.c_str());
}

@@ -33,12 +33,13 @@ private:
std::unique_ptr<BertPrivate> d_ptr;

protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string tokenToString(Token) const override;
std::string tokenToString(Token id) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override { return true; }
};

#endif // BERT_H

@@ -737,8 +737,10 @@ size_t GPTJ::restoreState(const uint8_t *src)
return gptj_set_state_data(d_ptr->model, &d_ptr->rng, src);
}

std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &, const std::string &str) const
std::vector<LLModel::Token> GPTJ::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
(void)ctx;
(void)special;
return ::gpt_tokenize(d_ptr->vocab, str);
}

@@ -30,12 +30,13 @@ private:
GPTJPrivate *d_ptr;

protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
Token sampleToken(PromptContext &ctx) const override;
std::string tokenToString(Token) const override;
std::string tokenToString(Token id) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override { return false; }
};

#endif // GPTJ_H

@@ -1 +1 @@
Subproject commit 7d4ced850548642b9a1740fa25ecdef249fbf47f
Subproject commit b61ee89fca2038e9937317a794e28e08391b7888

@@ -6,38 +6,29 @@
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <iomanip>
#include <iostream>
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h>
#else
#include <unistd.h>
#endif
#include <map>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include <llama.h>
#include <ggml.h>

#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#include <ggml-kompute.h>
#endif

using namespace std::string_literals;

// Maximum supported GGUF version
static constexpr int GGUF_VER_MAX = 3;

namespace {
const char *modelType_ = "LLaMA";
}
static const char * const modelType_ = "LLaMA";

static bool llama_verbose() {
const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");

@@ -96,6 +87,56 @@ static int llama_sample_top_p_top_k(
return llama_sample_token(ctx, &candidates_p);
}

std::string get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != (GGUF_TYPE_STRING)) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
return gguf_get_val_str(ctx_gguf, kid);
}

static gguf_context *load_gguf(const char *fname) {
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ nullptr,
};
gguf_context *ctx = gguf_init_from_file(fname, params);
if (!ctx) {
std::cerr << __func__ << ": gguf_init_from_file failed\n";
return nullptr;
}

int gguf_ver = gguf_get_version(ctx);
if (gguf_ver > GGUF_VER_MAX) {
std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
gguf_free(ctx);
return nullptr;
}

return ctx;
}

static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
auto * ctx = load_gguf(modelPath.c_str());
auto arch = get_arch_name(ctx);

int32_t value = -1;
if (ctx) {
auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
}
}

gguf_free(ctx);
return value;
}

struct LLamaPrivate {
const std::string modelPath;
bool modelLoaded;
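The new `load_gguf`/`get_arch_name` helpers only read the GGUF file header and reject anything newer than `GGUF_VER_MAX`. A minimal Python sketch of the same version gate, assuming only the standard GGUF layout of a 4-byte `GGUF` magic followed by a little-endian uint32 version:

```python
# Rough Python equivalent of the version check in load_gguf() above.
import struct

GGUF_VER_MAX = 3  # mirrors the C++ constant in this diff

def gguf_version_ok(path: str) -> bool:
    with open(path, "rb") as f:
        magic = f.read(4)
        if magic != b"GGUF":
            print(f"{path}: not a GGUF file")
            return False
        (version,) = struct.unpack("<I", f.read(4))  # little-endian uint32
    if version > GGUF_VER_MAX:
        print(f"{path}: unsupported gguf version: {version}")
        return False
    return True
```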
@@ -148,6 +189,42 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
return filesize + est_kvcache_size;
}

bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
auto * ctx = load_gguf(modelPath.c_str());
if (!ctx) {
std::cerr << __func__ << ": failed to load " << modelPath << "\n";
return false;
}

auto get_key = [ctx, &modelPath](const char *name) {
int keyidx = gguf_find_key(ctx, name);
if (keyidx == -1) {
throw std::logic_error(name + " not found in "s + modelPath);
}
return keyidx;
};

bool res = false;
try {
std::string name(gguf_get_val_str(ctx, get_key("general.name")));
int token_idx = get_key("tokenizer.ggml.tokens");
int n_vocab = gguf_get_arr_n(ctx, token_idx);

// check for known bad models
if (name == "open-orca_mistral-7b-openorca"
&& n_vocab == 32002
&& gguf_get_arr_str(ctx, token_idx, 32000) == "<dummy32000>"s // should be <|im_end|>
) {
res = true;
}
} catch (const std::logic_error &e) {
std::cerr << __func__ << ": " << e.what() << "\n";
}

gguf_free(ctx);
return res;
}

bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
d_ptr->modelLoaded = false;

@@ -290,12 +367,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}

std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
{
const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->model));
std::vector<LLModel::Token> fres(str.size()+4);
// TODO(cebtenzzre): we may want to use special=true here to process special tokens
auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, false);
const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
const bool useBOS = wantBOS && shouldAddBOS();
auto strCat = wantBOS && !special ? " " + str : str; // insert leading space ourselves, llama.cpp fork doesn't anymore
std::vector<LLModel::Token> fres(strCat.size()+4);
auto fres_len = llama_tokenize(d_ptr->model, strCat.c_str(), strCat.length(), fres.data(), fres.size(), useBOS, special);
fres.resize(fres_len);
return fres;
}
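The rewritten `LLamaModel::tokenize` adds BOS only for the very first text fed into a context and inserts the leading space itself when special-token parsing is off, since the bundled llama.cpp fork no longer does. A small, self-contained Python sketch of that argument preparation (names are stand-ins, not the real binding):

```python
from dataclasses import dataclass, field

@dataclass
class PromptContext:          # minimal stand-in for the C++ PromptContext
    n_past: int = 0
    tokens: list = field(default_factory=list)

def prepare_tokenize_args(ctx: PromptContext, text: str, special: bool, model_adds_bos: bool):
    """Mirrors the argument preparation in LLamaModel::tokenize() above."""
    want_bos = ctx.n_past == 0 and not ctx.tokens   # first text ever fed into this context
    use_bos = want_bos and model_adds_bos           # add BOS only if the vocab expects one
    if want_bos and not special:
        text = " " + text  # insert the leading space ourselves; the llama.cpp fork no longer does
    return text, use_bos, special                   # special=True lets <|im_start|> etc. become tokens

print(prepare_tokenize_args(PromptContext(), "hello", special=False, model_adds_bos=True))
```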
@@ -349,55 +427,10 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
return d_ptr->end_tokens;
}

std::string get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != (GGUF_TYPE_STRING)) {
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
}
return gguf_get_val_str(ctx_gguf, kid);
}

static gguf_context *load_gguf(const char *fname, std::string &arch) {
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ nullptr,
};
gguf_context *ctx = gguf_init_from_file(fname, params);
if (!ctx) {
std::cerr << __func__ << ": gguf_init_from_file failed\n";
return nullptr;
}

int gguf_ver = gguf_get_version(ctx);
if (gguf_ver > GGUF_VER_MAX) {
std::cerr << __func__ << ": unsupported gguf version: " << gguf_ver << "\n";
gguf_free(ctx);
return nullptr;
}

arch = get_arch_name(ctx);
return ctx;
}

static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
std::string arch;
auto * ctx = load_gguf(modelPath.c_str(), arch);

int32_t value = -1;
if (ctx) {
auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
}
}

gguf_free(ctx);
return value;
bool LLamaModel::shouldAddBOS() const
{
int add_bos = llama_add_bos_token(d_ptr->model);
return add_bos != -1 ? bool(add_bos) : llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_SPM;
}

int32_t LLamaModel::maxContextLength(std::string const &modelPath) const

@@ -513,8 +546,8 @@ DLL_EXPORT const char *get_build_variant() {
}

DLL_EXPORT bool magic_match(const char *fname) {
std::string arch;
auto * ctx = load_gguf(fname, arch);
auto * ctx = load_gguf(fname);
auto arch = get_arch_name(ctx);

bool valid = true;

@@ -19,6 +19,7 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelBlacklisted(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
@@ -27,7 +28,7 @@ public:
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired) const override;
bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const override;
bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const override;
bool initializeGPUDevice(int device, std::string *unavail_reason) const override;
bool hasGPUDevice() override;
bool usingGPUDevice() override;
@@ -36,12 +37,13 @@ private:
std::unique_ptr<LLamaPrivate> d_ptr;

protected:
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
std::string tokenToString(Token) const override;
Token sampleToken(PromptContext& ctx) const override;
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
int32_t contextLength() const override;
const std::vector<Token>& endTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;

int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override;

@@ -29,23 +29,23 @@ public:

class Implementation {
public:
Implementation(Dlhandle&&);
Implementation(const Implementation&) = delete;
Implementation(Implementation&&);
Implementation(Dlhandle &&);
Implementation(const Implementation &) = delete;
Implementation(Implementation &&);
~Implementation();

std::string_view modelType() const { return m_modelType; }
std::string_view buildVariant() const { return m_buildVariant; }

static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static bool isImplementation(const Dlhandle &dl);
static const std::vector<Implementation> &implementationList();
static const Implementation *implementation(const char *fname, const std::string &buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices();
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath();
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();

private:
static LLModel *constructDefaultLlama();
@@ -82,26 +82,30 @@ public:
virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual bool isModelBlacklisted(const std::string &modelPath) { (void)modelPath; return false; };
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; }
virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; }

// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx);
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);

virtual std::vector<float> embedding(const std::string &text);

virtual void setThreadCount(int32_t /*n_threads*/) {}
virtual void setThreadCount(int32_t n_threads) { (void)n_threads; }
virtual int32_t threadCount() const { return 1; }

const Implementation& implementation() const {
const Implementation &implementation() const {
return *m_implementation;
}

@@ -110,7 +114,7 @@ public:
return {};
}

virtual bool initializeGPUDevice(size_t memoryRequired, const std::string& name) const {
virtual bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const {
(void)memoryRequired;
(void)name;
return false;
@@ -132,12 +136,13 @@ public:
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
virtual std::string tokenToString(Token) const = 0;
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token>& endTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;

virtual int32_t maxContextLength(std::string const &modelPath) const
{
@@ -166,6 +171,15 @@ protected:
return true;
}

void decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx);

private:
friend class LLMImplementation;
};

@@ -1,8 +1,9 @@
#include "llmodel_c.h"
#include "llmodel.h"

#include <cstring>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <utility>

struct LLModelWrapper {
@@ -56,7 +57,14 @@ size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_c
bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path, n_ctx, ngl);

std::string modelPath(model_path);
if (wrapper->llModel->isModelBlacklisted(modelPath)) {
size_t slash = modelPath.find_last_of("/\\");
auto basename = slash == std::string::npos ? modelPath : modelPath.substr(slash + 1);
std::cerr << "warning: model '" << basename << "' is out-of-date, please check for an updated version\n";
}
return wrapper->llModel->loadModel(modelPath, n_ctx, ngl);
}

bool llmodel_isModelLoaded(llmodel_model model)
@@ -100,10 +108,12 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
}

void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx)
llmodel_prompt_context *ctx,
bool special)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);

@@ -131,7 +141,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.contextErase = ctx->context_erase;

// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
wrapper->llModel->prompt(prompt, prompt_template, prompt_func, response_func, recalc_func, wrapper->promptContext, special);

// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies

@@ -163,16 +163,20 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx);
llmodel_prompt_context *ctx,
bool special);

/**
* Generate an embedding using the model.

@@ -2,11 +2,20 @@

#include <cassert>
#include <iostream>
#include <regex>
#include <unordered_set>

// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
size_t i = 0;
promptCtx.n_past = 0;
int n_keep = shouldAddBOS();
const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;

// Erase the first percentage of context from the tokens
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);

size_t i = n_keep;
promptCtx.n_past = n_keep;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
|
||||
recalculate(false);
|
||||
}
|
||||
|
||||
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
|
||||
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
|
||||
|
||||
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
|
||||
placeholders.clear();
|
||||
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
|
||||
|
||||
if (placeholders.size() > 2) {
|
||||
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
|
||||
return false;
|
||||
}
|
||||
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
|
||||
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
|
||||
return false;
|
||||
}
|
||||
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
|
||||
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void LLModel::prompt(const std::string &prompt,
|
||||
const std::string &promptTemplate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &promptCtx)
|
||||
PromptContext &promptCtx,
|
||||
bool special,
|
||||
std::string *fakeReply)
|
||||
{
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
|
||||
@ -38,15 +72,86 @@ void LLModel::prompt(const std::string &prompt,
|
||||
}
|
||||
|
||||
if (!supportsCompletion()) {
|
||||
std::string errorMessage = "ERROR: this model does not support text completion or chat!\n";
|
||||
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
|
||||
responseCallback(-1, errorMessage);
|
||||
std::cerr << implementation().modelType() << errorMessage;
|
||||
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<Token> embd_inp = tokenize(promptCtx, prompt);
|
||||
// parse the prompt template
|
||||
std::vector<std::smatch> placeholders;
|
||||
{
|
||||
std::string err;
|
||||
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
|
||||
responseCallback(-1, err);
|
||||
std::cerr << err << "\n";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
|
||||
|
||||
// tokenize the user prompt
|
||||
std::vector<Token> embd_inp;
|
||||
if (placeholders.empty()) {
|
||||
// this is unusual, but well-defined
|
||||
std::cerr << __func__ << ": prompt template has no placeholder\n";
|
||||
embd_inp = tokenize(promptCtx, promptTemplate, true);
|
||||
} else {
|
||||
// template: beginning of user prompt
|
||||
const auto &phUser = placeholders[0];
|
||||
std::string userPrefix(phUser.prefix());
|
||||
if (!userPrefix.empty()) {
|
||||
embd_inp = tokenize(promptCtx, userPrefix, true);
|
||||
promptCtx.n_past += embd_inp.size();
|
||||
}
|
||||
|
||||
// user input (shouldn't have special token processing)
|
||||
auto tokens = tokenize(promptCtx, prompt, special);
|
||||
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
|
||||
promptCtx.n_past += tokens.size();
|
||||
|
||||
// template: end of user prompt + start of assistant prompt
|
||||
size_t start = phUser.position() + phUser.length();
|
||||
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
|
||||
auto userToAsst = promptTemplate.substr(start, end - start);
|
||||
if (!userToAsst.empty()) {
|
||||
tokens = tokenize(promptCtx, userToAsst, true);
|
||||
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
|
||||
promptCtx.n_past += tokens.size();
|
||||
}
|
||||
}
|
||||
|
||||
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
|
||||
|
||||
// decode the user prompt
|
||||
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
||||
|
||||
// decode the assistant's reply, either generated or spoofed
|
||||
if (fakeReply == nullptr) {
|
||||
generateResponse(responseCallback, recalculateCallback, promptCtx);
|
||||
} else {
|
||||
embd_inp = tokenize(promptCtx, *fakeReply, false);
|
||||
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
||||
}
|
||||
|
||||
// decode the rest of the prompt template
|
||||
if (placeholders.size() >= 2) {
|
||||
// template: end of assistant prompt
|
||||
size_t start = placeholders[1].position() + placeholders[1].length();
|
||||
auto asstSuffix = promptTemplate.substr(start);
|
||||
if (!asstSuffix.empty()) {
|
||||
embd_inp = tokenize(promptCtx, asstSuffix, true);
|
||||
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &promptCtx,
|
||||
std::vector<Token> embd_inp) {
|
||||
// save the context size
|
||||
promptCtx.n_ctx = contextLength();
|
||||
|
||||
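This is the heart of the change: `LLModel::prompt` splits the template at the `%1`/`%2` placeholders, decodes the template pieces with special-token parsing enabled, decodes the raw user prompt without it, and either generates a reply or replays a `fakeReply`. A simplified Python sketch of the control flow (strings instead of token batches, no context management):

```python
import re

PLACEHOLDER = re.compile(r"%[1-2](?![0-9])")   # same pattern as parsePromptTemplate

def run_turn(prompt_template: str, user_prompt: str, generate, fake_reply=None) -> str:
    ph = list(PLACEHOLDER.finditer(prompt_template))
    if not ph:                                  # unusual, but well-defined: template used verbatim
        return generate(prompt_template)

    user_prefix = prompt_template[:ph[0].start()]                   # template before %1 (special tokens honored)
    end = ph[1].start() if len(ph) >= 2 else len(prompt_template)
    user_to_asst = prompt_template[ph[0].end():end]                 # template between %1 and %2
    asst_suffix = prompt_template[ph[1].end():] if len(ph) >= 2 else ""

    decoded = user_prefix + user_prompt + user_to_asst              # user text itself is taken literally
    reply = fake_reply if fake_reply is not None else generate(decoded)
    return decoded + reply + asst_suffix

print(run_turn("<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2\n",
               "hi there", generate=lambda ctx: "<generated reply>"))
```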
@@ -69,11 +174,6 @@ void LLModel::prompt(const std::string &prompt,

// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
@@ -94,7 +194,11 @@ void LLModel::prompt(const std::string &prompt,
}
i = batch_end;
}
}

void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
std::string cachedResponse;
std::vector<Token> cachedTokens;
std::unordered_set<std::string> reversePrompts
@@ -108,11 +212,6 @@ void LLModel::prompt(const std::string &prompt,

// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
@@ -165,8 +264,9 @@ void LLModel::prompt(const std::string &prompt,
}
}

std::vector<float> LLModel::embedding(const std::string &/*text*/)
std::vector<float> LLModel::embedding(const std::string &text)
{
(void)text;
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support generating embeddings!\n";
std::cerr << implementation().modelType() << errorMessage;

@@ -246,90 +246,6 @@ To do the same outside a session, the input has to be formatted manually. For ex
The colors in my previous response are blue, green and red.
```

Ultimately, the method `GPT4All._format_chat_prompt_template()` is responsible for formatting templates. It can be
customized in a subclass. As an example:

=== "Custom Subclass"
``` py
from itertools import cycle
from gpt4all import GPT4All

class RotatingTemplateGPT4All(GPT4All):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._templates = [
"Respond like a pirate.",
"Respond like a politician.",
"Respond like a philosopher.",
"Respond like a Klingon.",
]
self._cycling_templates = cycle(self._templates)

def _format_chat_prompt_template(
self,
messages: list,
default_prompt_header: str = "",
default_prompt_footer: str = "",
) -> str:
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
for message in messages:
if message["role"] == "user":
user_message = f"USER: {message['content']} {next(self._cycling_templates)}\n"
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = f"ASSISTANT: {message['content']}\n"
full_prompt += assistant_message
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
print(full_prompt)
return full_prompt
```
=== "GPT4All Custom Subclass Example"
``` py
model = RotatingTemplateGPT4All('wizardlm-13b-v1.2.Q4_0.gguf')
with model.chat_session(): # starting a session is optional in this example
response1 = model.generate("hi, who are you?")
print(response1)
print()
response2 = model.generate("what can you tell me about snakes?")
print(response2)
print()
response3 = model.generate("what's your opinion on Chess?")
print(response3)
print()
response4 = model.generate("tell me about ancient Rome.")
print(response4)
```
=== "Possible Output"
```
USER: hi, who are you? Respond like a pirate.

Pirate: Ahoy there mateys! I be Cap'n Jack Sparrow of the Black Pearl.

USER: what can you tell me about snakes? Respond like a politician.

Politician: Snakes have been making headlines lately due to their ability to
slither into tight spaces and evade capture, much like myself during my last
election campaign. However, I believe that with proper education and
understanding of these creatures, we can work together towards creating a
safer environment for both humans and snakes alike.

USER: what's your opinion on Chess? Respond like a philosopher.

Philosopher: The game of chess is often used as an analogy to illustrate the
complexities of life and decision-making processes. However, I believe that it
can also be seen as a reflection of our own consciousness and subconscious mind.
Just as each piece on the board has its unique role to play in shaping the
outcome of the game, we too have different roles to fulfill in creating our own
personal narrative.

USER: tell me about ancient Rome. Respond like a Klingon.

Klingon: Ancient Rome was once a great empire that ruled over much of Europe and
the Mediterranean region. However, just as the Empire fell due to internal strife
and external threats, so too did my own house come crashing down when I failed to
protect our homeworld from invading forces.
```


### Introspection
A less apparent feature is the capacity to log the final prompt that gets sent to the model. It relies on

@@ -89,10 +89,12 @@ RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
RecalculateCallback,
ctypes.POINTER(LLModelPromptContext),
ctypes.c_bool,
]

llmodel.llmodel_prompt.restype = None
@@ -290,6 +292,7 @@ class LLModel:
def prompt_model(
self,
prompt: str,
prompt_template: str,
callback: ResponseCallbackType,
n_predict: int = 4096,
top_k: int = 40,
@@ -300,6 +303,7 @@ class LLModel:
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
special: bool = False,
):
"""
Generate response from model from a prompt.
@@ -326,9 +330,6 @@ class LLModel:
prompt,
)

prompt_bytes = prompt.encode()
prompt_ptr = ctypes.c_char_p(prompt_bytes)

self._set_context(
n_predict=n_predict,
top_k=top_k,
@@ -343,16 +344,18 @@ class LLModel:

llmodel.llmodel_prompt(
self.model,
prompt_ptr,
ctypes.c_char_p(prompt.encode()),
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)),
RecalculateCallback(self._recalculate_callback),
self.context,
special,
)


def prompt_model_streaming(
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
output_queue: Queue[str | Sentinel] = Queue()

@@ -369,15 +372,15 @@ class LLModel:

return _generator_callback

def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, prompt_template, callback, **kwargs)
output_queue.put(Sentinel.TERMINATING_SYMBOL)

# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
thread = threading.Thread(
target=run_llmodel_prompt,
args=(prompt, _generator_callback_wrapper(callback)),
args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
kwargs=kwargs,
)
thread.start()
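At the Python binding layer the only change is that `prompt_model` now forwards a prompt template and a `special` flag down to `llmodel_prompt`. A self-contained sketch of that argument marshalling with the C call mocked out (the real call goes through `llmodel.llmodel_prompt` with the argtypes listed above):

```python
# Mocked sketch of the new argument flow; nothing here touches the real shared library.
import ctypes

def fake_llmodel_prompt(model, prompt, prompt_template, special):
    # stand-in for llmodel.llmodel_prompt(...); just show what crosses the C boundary
    print("prompt          :", prompt.value)
    print("prompt_template :", prompt_template.value)
    print("special         :", special)

def prompt_model(prompt: str, prompt_template: str, special: bool = False):
    fake_llmodel_prompt(
        None,                                       # llmodel_model handle
        ctypes.c_char_p(prompt.encode()),
        ctypes.c_char_p(prompt_template.encode()),
        ctypes.c_bool(special),
    )

prompt_model("hi", "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n")
```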
@@ -4,8 +4,10 @@ Python only API for running all GPT4All models.
from __future__ import annotations

import os
import re
import sys
import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
@@ -314,6 +316,10 @@ class GPT4All:
Either the entire completion or a generator that yields the completion token by token.
"""

if re.search(r"%1(?![0-9])", self._current_prompt_template):
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
"placeholder, please use '{0}' instead.")

# Preparing the model request
generate_kwargs: Dict[str, Any] = dict(
temp=temp,
@@ -327,16 +333,29 @@ class GPT4All:

if self._is_chat_session_activated:
# check if there is only one message, i.e. system prompt:
generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
reset = len(self.current_chat_session) == 1
generate_kwargs["reset_context"] = reset
self.current_chat_session.append({"role": "user", "content": prompt})

prompt = self._format_chat_prompt_template(
messages=self.current_chat_session[-1:],
default_prompt_header=self.current_chat_session[0]["content"]
if generate_kwargs["reset_context"]
else "",
)
if self._format_chat_prompt_template.__func__ is GPT4All._format_chat_prompt_template:
if reset:
# ingest system prompt
self.model.prompt_model(self.current_chat_session[0]["content"], "%1",
n_batch=n_batch, n_predict=0, special=True)
prompt_template = self._current_prompt_template.format("%1")
else:
warnings.warn(
"_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
DeprecationWarning,
)
# special tokens won't be processed
prompt = self._format_chat_prompt_template(
self.current_chat_session[-1:],
self.current_chat_session[0]["content"] if reset else "",
)
prompt_template = "%1"
else:
prompt_template = "%1"
generate_kwargs["reset_context"] = True

# Prepare the callback, process the model response
@@ -365,14 +384,16 @@ class GPT4All:
# Send the request to the model
if streaming:
return self.model.prompt_model_streaming(
prompt=prompt,
callback=_callback_wrapper(callback, output_collector),
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)

self.model.prompt_model(
prompt=prompt,
callback=_callback_wrapper(callback, output_collector),
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)

@@ -423,24 +444,6 @@ class GPT4All:
Formatted prompt.
"""

if isinstance(default_prompt_header, bool):
import warnings

warnings.warn(
"Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
DeprecationWarning,
)
default_prompt_header = ""

if isinstance(default_prompt_footer, bool):
import warnings

warnings.warn(
"Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
DeprecationWarning,
)
default_prompt_footer = ""

full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""

for message in messages:
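From the high-level API, the intended usage is to give `chat_session()` a chat-style template. Note that the high-level template uses `{0}` as the placeholder (a literal `%1` now raises `ValueError`, as enforced above) and is converted to `%1` internally. A hedged example, assuming the `chat_session(system_prompt=..., prompt_template=...)` signature from the Python docs; the model file name is just an example:

```python
from gpt4all import GPT4All

model = GPT4All("mistral-7b-openorca.Q4_0.gguf2.gguf")  # example model name
with model.chat_session(
    system_prompt="<|im_start|>system\nYou are MistralOrca.\n<|im_end|>",
    prompt_template="<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant\n",
):
    # the system prompt is ingested with special=True; user text is not scanned for special tokens
    print(model.generate("Why is the sky blue?", max_tokens=128))
```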
@@ -68,7 +68,7 @@ def get_long_description():

setup(
name=package_name,
version="2.2.1.post1",
version="2.3.0",
description="Python bindings for GPT4All",
long_description=get_long_description(),
long_description_content_type="text/markdown",

@@ -75,13 +75,18 @@ size_t ChatGPT::restoreState(const uint8_t *src)
}

void ChatGPT::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
PromptContext &promptCtx,
bool special,
std::string *fakeReply) {

Q_UNUSED(promptCallback);
Q_UNUSED(recalculateCallback);
Q_UNUSED(special);
Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT

if (!isModelLoaded()) {
std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
@@ -109,7 +114,7 @@ void ChatGPT::prompt(const std::string &prompt,

QJsonObject promptObject;
promptObject.insert("role", "user");
promptObject.insert("content", QString::fromStdString(prompt));
promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
messages.append(promptObject);
root.insert("messages", messages);

@@ -1,6 +1,8 @@
#ifndef CHATGPT_H
#define CHATGPT_H

#include <stdexcept>

#include <QObject>
#include <QNetworkReply>
#include <QNetworkRequest>
@@ -55,10 +57,13 @@ public:
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override;
PromptContext &ctx,
bool special,
std::string *fakeReply) override;

void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
@@ -69,7 +74,7 @@ public:
QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }

bool callResponse(int32_t token, const std::string& string);
bool callResponse(int32_t token, const std::string &string);

Q_SIGNALS:
void request(const QString &apiKey,
@@ -80,12 +85,41 @@ protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
std::string tokenToString(Token) const override { return std::string(); }
Token sampleToken(PromptContext &ctx) const override { return -1; }
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
int32_t contextLength() const override { return -1; }
const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }

std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
(void)ctx;
(void)str;
(void)special;
throw std::logic_error("not implemented");
}

std::string tokenToString(Token id) const override {
(void)id;
throw std::logic_error("not implemented");
}

Token sampleToken(PromptContext &ctx) const override {
(void)ctx;
throw std::logic_error("not implemented");
}

bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}

int32_t contextLength() const override {
throw std::logic_error("not implemented");
}

const std::vector<Token> &endTokens() const override {
throw std::logic_error("not implemented");
}

bool shouldAddBOS() const override {
throw std::logic_error("not implemented");
}

private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;

@@ -303,6 +303,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);

if (m_llModelInfo.model) {
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
// TODO(cebtenzzre): warn that this model is out-of-date
}

m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
emit modelLoadingPercentageChanged(progress);
@@ -588,14 +591,11 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
}

// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
QList<QString> docsContext;
if (!databaseResults.isEmpty())
augmentedTemplate.append("### Context:");
docsContext.append("### Context:");
for (const ResultInfo &info : databaseResults)
augmentedTemplate.append(info.text);
augmentedTemplate.append(promptTemplate);

QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
docsContext.append(info.text);

int n_threads = MySettings::globalInstance()->threadCount();

@@ -605,7 +605,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
qint32 logitsBefore = m_ctx.logits.size();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
@@ -615,11 +614,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
printf("%s", qPrintable(prompt));
fflush(stdout);
#endif
m_timer->start();
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -720,7 +724,7 @@ void ChatLLM::generateName()
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -780,16 +784,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}

bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}

bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -808,16 +802,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}

bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}

bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -1027,8 +1011,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);

const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1051,7 +1033,9 @@ void ChatLLM::processSystemPrompt()
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -1073,8 +1057,6 @@ void ChatLLM::processRestoreStateFromText()
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);

const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
@@ -1094,9 +1076,19 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);

auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());

auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
auto responseText = response.second.toStdString();

m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
}

if (!m_stopGenerating) {
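`processRestoreStateFromText` now walks the saved history as (prompt, response) pairs and replays each stored response through the new `fakeReply` path instead of regenerating it. A rough Python sketch of that pairing loop with the backend call stubbed out:

```python
# decode_turn is a stand-in for LLModel::prompt(..., fakeReply=&responseText).
def restore_state(history: list[tuple[str, str]], prompt_template: str, decode_turn):
    it = iter(history)
    for (kind_p, prompt), (kind_r, response) in zip(it, it):  # consume pairwise: prompt then response
        assert kind_p == "Prompt: " and kind_r != "Prompt: "
        decode_turn(prompt, prompt_template, fake_reply=response)

restore_state(
    [("Prompt: ", "hi"), ("Response: ", "hello!")],
    "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
    decode_turn=lambda p, t, fake_reply: print(f"replaying: {p!r} -> {fake_reply!r}"),
)
```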
@@ -17,10 +17,10 @@
},
{
"order": "b",
"md5sum": "48de9538c774188eb25a7e9ee024bbd3",
"md5sum": "f692417a22405d80573ac10cb0cd6c6a",
"name": "Mistral OpenOrca",
"filename": "mistral-7b-openorca.Q4_0.gguf",
"filesize": "4108927744",
"filename": "mistral-7b-openorca.Q4_0.gguf2.gguf",
"filesize": "4108928128",
"requires": "2.5.0",
"ramrequired": "8",
"parameters": "7 billion",
@@ -28,7 +28,7 @@
"type": "Mistral",
"description": "<strong>Best overall fast chat model</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by Mistral AI<li>Finetuned on OpenOrca dataset curated via <a href=\"https://atlas.nomic.ai/\">Nomic Atlas</a><li>Licensed for commercial use</ul>",
"url": "https://gpt4all.io/models/gguf/mistral-7b-openorca.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>"
},
{
@@ -152,7 +152,7 @@
"type": "MPT",
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
},
{
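The models.json fix is easiest to see by rendering one turn with the old and the corrected ChatML template; the new one puts a newline between `<|im_end|>` and the assistant tag (and the entry now points at the re-uploaded Mistral OpenOrca file):

```python
# Illustrative only: show what the template fix changes for a single turn.
old = "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n"
new = "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n"

for name, tmpl in [("old", old), ("new", new)]:
    print(f"--- {name} ---")
    print(tmpl.replace("%1", "Why is the sky blue?"))
```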
@@ -951,7 +951,7 @@ void ModelList::updateModelsFromDirectory()
processDirectory(localPath);
}

#define MODELS_VERSION 2
#define MODELS_VERSION 3

void ModelList::updateModelsFromJson()
{