
Remove binary state from high-level API and use Jinja templates ()

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
Jared Van Bortel 2024-11-25 10:04:17 -05:00 committed by GitHub
parent 3320094d29
commit 225bf6be93
54 changed files with 3423 additions and 2224 deletions

3
.gitmodules vendored
View File

@ -17,3 +17,6 @@
[submodule "gpt4all-chat/deps/QXlsx"]
path = gpt4all-chat/deps/QXlsx
url = https://github.com/nomic-ai/QXlsx.git
[submodule "gpt4all-chat/deps/Jinja2Cpp"]
path = gpt4all-chat/deps/Jinja2Cpp
url = https://github.com/nomic-ai/jinja2cpp.git

View File

@ -5,6 +5,7 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <expected>
#include <functional>
#include <optional>
#include <span>
@ -24,6 +25,10 @@ using namespace std::string_literals;
class LLModel {
public:
using Token = int32_t;
using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
using ProgressCallback = std::function<bool(float progress)>;
class BadArchError: public std::runtime_error {
public:
@ -101,6 +106,7 @@ public:
static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath);
static bool isEmbeddingModel(const std::string &modelPath);
static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath();
static bool hasSupportedCPU();
@ -124,7 +130,6 @@ public:
};
struct PromptContext {
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@ -136,8 +141,6 @@ public:
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {}
virtual ~LLModel() {}
@ -154,16 +157,12 @@ public:
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::optional<std::string_view> fakeReply = {});
virtual void prompt(std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &ctx);
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
virtual int32_t countPromptTokens(std::string_view prompt) const;
virtual size_t embeddingSize() const {
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@ -209,23 +208,22 @@ public:
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
virtual int32_t contextLength() const = 0;
virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
virtual std::vector<Token> tokenize(std::string_view str) const = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual void initSampler(const PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
virtual void setModelInputPosition(int32_t pos) = 0;
virtual void appendInputToken(Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@ -242,6 +240,12 @@ protected:
return -1;
}
virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
(void)modelPath;
return std::unexpected("not implemented");
}
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
@ -253,19 +257,15 @@ protected:
return true;
}
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
protected:
Token m_tokenize_last_token = -1; // not serialized
// prefill context with prompt
auto decodePrompt(const PromptCallback &promptCallback,
const PromptContext &promptCtx,
std::vector<Token> embd_inp)
-> std::optional<int32_t>;
// generate a response
void generateResponse(const ResponseCallback &responseCallback,
const PromptContext &promptCtx,
int32_t nPast);
friend class LLMImplementation;
};

View File

@ -35,16 +35,15 @@ typedef int32_t token_t;
* behavior.
*/
struct llmodel_prompt_context {
int32_t n_past; // number of tokens in past conversation
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution
float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution
int32_t n_batch; // number of predictions to generate in parallel
float repeat_penalty; // penalty factor for repeated tokens
float repeat_penalty; // penalty factor for repeated tokens
int32_t repeat_last_n; // last n tokens to penalize
float context_erase; // percent of context to erase if we exceed the context window
float context_erase; // percent of context to erase if we exceed the context window
};
struct llmodel_gpu_device {
@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;
/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
* @param token_ids An array of token ids of the prompt.
* @param n_token_ids The number of tokens in the array.
* @param cached Whether the tokens were already in cache.
* @return a bool indicating whether the model should keep processing.
*/
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);
/**
* Callback type for response.
@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);
/**
* Embedding cancellation callback for use with llmodel_embed.
@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
*/
typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
typedef void (*llmodel_special_token_callback)(const char *name, const char *token);
/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
* @param error A pointer to a string; will only be set on error.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);
bool llmodel_prompt(llmodel_model model,
const char *prompt,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_prompt_context *ctx,
const char **error);
/**
* Generate an embedding using the model.
@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
*/
const char *llmodel_model_gpu_device_name(llmodel_model model);
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
#ifdef __cplusplus
}
#endif

View File

@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx);
} else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
}
}
@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const
return bytesRead;
}
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str) const
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
/*parse_special*/ special, /*insert_space*/ insertSpace
int32_t fres_len = llama_tokenize(
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true
);
fres.resize(fres_len);
if (fres_len)
m_tokenize_last_token = fres.back();
return fres;
}
@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const
return std::string(result.data(), result.size());
}
void LLamaModel::initSampler(PromptContext &promptCtx)
void LLamaModel::initSampler(const PromptContext &promptCtx)
{
auto *model = d_ptr->model;
auto *chain = d_ptr->sampler_chain;
@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
bool LLamaModel::evalTokens(int32_t nPast, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
assert(!tokens.empty());
llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1);
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token [i] = tokens[i];
batch.pos [i] = ctx.n_past + i;
batch.pos [i] = nPast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i][0] = 0;
batch.logits [i] = false;
@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
return res == 0;
}
void LLamaModel::shiftContext(PromptContext &promptCtx)
void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast)
{
// infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_past = *nPast;
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));
assert(n_discard > 0);
@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
*nPast = inp.size();
}
int32_t LLamaModel::contextLength() const
@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx);
}
auto LLamaModel::specialTokens() -> std::unordered_map<std::string, std::string> const
{
if (!d_ptr->model)
throw std::logic_error("model not loaded");
std::unordered_map<std::string, std::string> tokens;
if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("bos_token", tokenToString(id));
if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("eos_token", tokenToString(id));
return tokens;
}
int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}
void LLamaModel::setTokenizeInputPosition(int32_t pos)
int32_t LLamaModel::computeModelInputPosition(std::span<const Token> input) const
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}
auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}
// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
++cacheIt; ++inputIt;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
return inputIt - input.begin();
}
void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
void LLamaModel::setModelInputPosition(int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}
void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
void LLamaModel::appendInputToken(Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}
auto LLamaModel::inputTokens() const -> std::span<const Token>
@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const
return get_arch_key_u32(modelPath, "block_count");
}
// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded
// models into a class that keeps the model file open
auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
auto *ctx = load_gguf(modelPath);
if (!ctx)
return std::unexpected("failed to open model file");
std::expected<std::string, std::string> result;
enum gguf_type ktype;
const int kid = gguf_find_key(ctx, "tokenizer.chat_template");
if (kid == -1) {
result = std::unexpected("key not found");
goto cleanup;
}
ktype = gguf_get_kv_type(ctx, kid);
if (ktype != GGUF_TYPE_STRING) {
result = std::unexpected(
"expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype)
);
goto cleanup;
}
result = gguf_get_val_str(ctx, kid);
cleanup:
gguf_free(ctx);
return result;
}
#ifdef GGML_USE_VULKAN
static const char *getVulkanVendorName(uint32_t vendorID)
{

View File

@ -11,6 +11,7 @@
#include <string>
#include <string_view>
#include <vector>
#include <unordered_map>
struct LLamaPrivate;
struct EmbModelSpec;
@ -49,26 +50,26 @@ public:
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
int32_t contextLength() const override;
auto specialTokens() -> std::unordered_map<std::string, std::string> const override;
protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
std::vector<Token> tokenize(std::string_view str) const override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
void initSampler(const PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override;
void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
int32_t computeModelInputPosition(std::span<const Token> input) const override;
void setModelInputPosition(int32_t pos) override;
void appendInputToken(Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override;
auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string> override;
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,

View File

@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath)
return llama && llama->isEmbeddingModel(modelPath);
}
auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>
{
auto *llama = constructGlobalLlama();
return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available");
}
void LLModel::Implementation::setImplementationsSearchPath(const std::string& path)
{
s_implementations_search_path = path;

View File

@ -7,7 +7,6 @@
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token));
struct LLModelWrapper {
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; }
};
@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply)
bool llmodel_prompt(llmodel_model model,
const char *prompt,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_prompt_context *ctx,
const char **error)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
auto response_func = [response_callback](int32_t token_id, const std::string &response) {
return response_callback(token_id, response.c_str());
// Copy the C prompt context
LLModel::PromptContext promptContext {
.n_predict = ctx->n_predict,
.top_k = ctx->top_k,
.top_p = ctx->top_p,
.min_p = ctx->min_p,
.temp = ctx->temp,
.n_batch = ctx->n_batch,
.repeat_penalty = ctx->repeat_penalty,
.repeat_last_n = ctx->repeat_last_n,
.contextErase = ctx->context_erase,
};
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.min_p = ctx->min_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
auto prompt_func = [prompt_callback](std::span<const LLModel::Token> token_ids, bool cached) {
return prompt_callback(token_ids.data(), token_ids.size(), cached);
};
auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) {
return response_callback(token_id, piece.data());
};
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
try {
wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext);
} catch (std::exception const &e) {
llmodel_set_error(error, e.what());
return false;
}
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->min_p = wrapper->promptContext.min_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
return true;
}
float *llmodel_embed(
@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model)
const auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->gpuDeviceName();
}
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
try {
return wrapper->llModel->countPromptTokens(prompt);
} catch (const std::exception& e) {
llmodel_set_error(error, e.what());
return -1;
}
}
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
for (auto &[name, token] : wrapper->llModel->specialTokens())
callback(name.c_str(), token.c_str());
}

View File

@ -4,232 +4,120 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <optional>
#include <regex>
#include <sstream>
#include <ranges>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
namespace ranges = std::ranges;
namespace views = std::ranges::views;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
void LLModel::prompt(
std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
if (!isModelLoaded())
throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!supportsCompletion())
throw std::invalid_argument("Not a text completion model.");
if (!promptCtx.n_batch)
throw std::invalid_argument("Batch size cannot be zero.");
if (!promptCtx.n_predict)
return; // nothing requested
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
placeholders.clear();
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
auto embd_inp = tokenize(prompt);
if (embd_inp.empty())
throw std::invalid_argument("Prompt tokenized to zero tokens.");
if (placeholders.size() > 2) {
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp)))
generateResponse(responseCallback, promptCtx, /*n_past*/ *res);
}
void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::optional<std::string_view> fakeReply)
int32_t LLModel::countPromptTokens(std::string_view prompt) const
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
return;
}
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
responseCallback(-1, errorMessage);
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
return;
}
// sanity checks
if (promptCtx.n_past > contextLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
setTokenizeInputPosition(promptCtx.n_past);
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty())
embd_inp = tokenize(userPrefix, true);
// user input (shouldn't have special token processing)
auto tokens = tokenize(prompt, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(userToAsst, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
}
}
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error
// decode the assistant's reply, either generated or spoofed
if (!fakeReply) {
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(*fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
return; // error
}
// decode the rest of the prompt template
// template: end of assistant prompt
std::string asstSuffix;
if (placeholders.size() >= 2) {
size_t start = placeholders[1].position() + placeholders[1].length();
asstSuffix = promptTemplate.substr(start);
} else {
asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(asstSuffix, true);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
}
if (!isModelLoaded())
throw std::invalid_argument("Attempted to tokenize with an unloaded model.");
return int32_t(tokenize(prompt).size());
}
// returns false on error
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse,
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << contextLength() << "!\n";
return false;
}
auto LLModel::decodePrompt(
const PromptCallback &promptCallback,
const PromptContext &promptCtx,
std::vector<Token> embd_inp
) -> std::optional<int32_t>
{
assert(!embd_inp.empty());
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << contextLength() << "\n";
return false;
}
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}
int32_t nCtx = contextLength();
int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();
int32_t nPast = computeModelInputPosition(embd_inp);
// always decode up to a full batch before generating, even if cached
if (alwaysDecode)
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
nPast -= std::min(n_batch, nPast);
setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
// TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache
if (!nPast && int32_t(embd_inp.size()) > nCtx) {
// no cache hit -> shift the input before even processing
// execute the callback even for skipped tokens
size_t i = 0;
for (; i < start_offset; i++) {
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
int32_t nKeep = shouldAddBOS();
auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase));
int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength));
// execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care
auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard);
if (!promptCallback(discardedTokens, true))
return std::nullopt;
// erase nDiscard tokens
embd_inp.erase(discardedTokens.begin(), discardedTokens.end());
assert(int32_t(embd_inp.size()) <= nCtx);
// check the cache again, just in case
nPast = computeModelInputPosition(embd_inp);
nPast -= std::min(n_batch, nPast);
}
setModelInputPosition(nPast);
// execute the callback even for skipped tokens
if (!promptCallback(embd_inp | views::take(nPast), true))
return std::nullopt;
// process the prompt in batches
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
for (int32_t i = nPast; i < embd_inp.size();) {
auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size()));
std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
if (nPast + int32_t(batch.size()) > nCtx) {
shiftContext(promptCtx, &nPast);
assert(nPast + int32_t(batch.size()) <= nCtx);
}
if (!evalTokens(promptCtx, batch)) {
std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
return false;
}
// FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation
if (!evalTokens(nPast, batch))
throw std::runtime_error("An internal error was encountered during prompt processing.");
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
Token tok = batch[t];
appendInputToken(promptCtx, tok);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
for (auto &tok : batch) {
appendInputToken(tok);
nPast++;
if (!promptCallback({ &tok, 1 }, false))
return std::nullopt;
}
i = batch_end;
}
return true;
return nPast;
}
/*
@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st
return std::string::npos;
}
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx) {
void LLModel::generateResponse(
const ResponseCallback &responseCallback,
const PromptContext &promptCtx,
int32_t nPast
) {
static const char *stopSequences[] {
"### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
"### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context",
"<|im_start|>", "<|im_end|>", "<|endoftext|>",
};
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
initSampler(promptCtx);
std::string cachedResponse;
@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece;
auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
auto accept = [this, &promptCtx, &new_tok, &nPast] {
// Shift context if out of space
if (promptCtx.n_past >= contextLength()) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < contextLength());
if (nPast >= contextLength()) {
shiftContext(promptCtx, &nPast);
assert(nPast < contextLength());
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { &tok, 1 })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
if (!evalTokens(nPast, { &tok, 1 }))
throw std::runtime_error("An internal error was encountered during response generation.");
appendInputToken(promptCtx, tok);
return true;
appendInputToken(tok);
nPast++;
};
// Check for EOS
@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
lengthLimit = cachedResponse.size() - new_piece.size();
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< contextLength() << "\n";
stop = true;
}
// Empty the cache, up to the length limit
std::string::size_type responseLength = 0;
while (!cachedTokens.empty()) {
@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());
// Accept the token, if needed (not cached)
if (cachedTokens.empty() && new_tok && !accept())
return;
if (cachedTokens.empty() && new_tok)
accept();
// Send the token
if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
if (stop) {
cachedTokens.pop_back();
} else if (!accept()) {
return;
} else {
accept();
}
}
}
@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif
promptCtx.n_past -= cachedTokens.size();
}
void LLModel::embed(

View File

@ -0,0 +1,206 @@
## What are chat templates?
Natively, large language models only know how to complete plain text and do not know the difference between their input and their output. In order to support a chat with a person, LLMs are designed to use a template to convert the conversation to plain text using a specific format.
For a given model, it is important to use an appropriate chat template, as each model is designed to work best with a specific format. The chat templates included with the built-in models should be sufficient for most purposes.
There are two reasons you would want to alter the chat template:
- You are sideloading a model and there is no chat template available,
- You would like to have greater control over the input to the LLM than a system message provides.
## What is a system message?
A system message is a message that controls the responses from the LLM in a way that affects the entire conversation. System messages can be short, such as "Speak like a pirate.", or they can be long and contain a lot of context for the LLM to keep in mind.
Not all models are designed to use a system message, so system messages work better with some models than with others.
## How do I customize the chat template or system message?
To customize the chat template or system message, go to Settings > Model. Make sure to select the correct model at the top. If you clone a model, the clone can use a different chat template or system message than the base model, which lets you keep different settings for different conversations.
These settings take effect immediately. After changing them, you can click "Redo last response" in the chat view, and the response will take the new settings into account.
## Do I need to write a chat template?
You typically do not need to write your own chat template. The exception is models that are not in the official model list and do not come with a chat template built-in. These will show a "Clear" option above the chat template field in the Model Settings page instead of a "Reset" option. See the section on [finding] or [creating] a chat template.
[finding]: #how-do-i-find-a-chat-template
[creating]: #advanced-how-do-chat-templates-work
## What changed in GPT4All v3.5?
GPT4All v3.5 overhauled the chat template system. There are three crucial differences:
- The chat template now formats an entire conversation instead of a single pair of messages,
- The chat template now uses Jinja syntax instead of `%1` and `%2` placeholders,
- And the system message should no longer contain control tokens or trailing whitespace.
If you added or altered a chat template or system message from its default before upgrading to GPT4All v3.5 or newer, it will no longer work. See below for how to solve common errors you may see after upgrading.
## Error/Warning: System message is not plain text.
This is easy to fix. Go to the model's settings and look at the system prompt. There are three things to look for:
- Control tokens such as `<|im_start|>`, `<|start_header_id|>`, or `<|system|>`
- A prefix such as `### System` or `SYSTEM:`
- Trailing whitespace, such as a space character or blank line.
If you see any of these things, remove them. For example, this legacy system prompt:
```
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
```
Should become this:
```
You are a helpful assistant.
```
If you do not see anything that needs to be changed, you can dismiss the error by making a minor modification to the message and then changing it back.
If you see a warning rather than an error, GPT4All merely suspects that your system message is not plain text. If you believe this warning is incorrect, it can be safely ignored. If in doubt, ask on the [Discord].
[Discord]: https://discord.gg/mGZE39AS3e
## Error: Legacy system prompt needs to be updated in Settings.
This is the same as [above][above-1], but appears on the chat page.
[above-1]: #errorwarning-system-message-is-not-plain-text
## Error/Warning: Chat template is not in Jinja format.
This is the result of attempting to use an old-style template (possibly from a previous version) in GPT4All 3.5+.
Go to the Model Settings page and select the affected model. If you see a "Reset" button, and you have not intentionally modified the prompt template, you can click "Reset". Otherwise, this is what you can do:
1. Back up your chat template by copying it safely to a text file and saving it. In the next step, it will be removed from GPT4All.
2. Click "Reset" or "Clear".
3. If you clicked "Clear", the chat template is now gone. Follow the steps to [find][finding] or [create][creating] a basic chat template for your model.
4. Customize the chat template to suit your needs. For help, read the section about [creating] a chat template.
## Error: Legacy prompt template needs to be updated in Settings.
This is the same as [above][above-2], but appears on the chat page.
[above-2]: #errorwarning-chat-template-is-not-in-jinja-format
## The chat template has a syntax error.
If there is a syntax error while editing the chat template, the details will be displayed in an error message above the input box. This could be because the chat template is not actually in Jinja format (see [above][above-2]).
Otherwise, you have either typed something incorrectly, or the model comes with a template that is incompatible with GPT4All. See [the below section][creating] on creating chat templates and make sure that everything is correct. When in doubt, ask on the [Discord].
## Error: No chat template configured.
This may appear for models that are not from the official model list and do not include a chat template. Older versions of GPT4All picked a poor default in this case. You will get much better results if you follow the steps to [find][finding] or [create][creating] a chat template for your model.
## Error: The chat template cannot be blank.
If the button above the chat template on the Model Settings page says "Clear", see [above][above-3]. If you see "Reset", click that button to restore a reasonable default. Also see the section on [syntax errors][chat-syntax-error].
[above-3]: #error-no-chat-template-configured
[chat-syntax-error]: #the-chat-template-has-a-syntax-error
## How do I find a chat template?
When in doubt, you can always ask the [Discord] community for help. Below are the instructions to find one on your own.
The authoritative source for a model's chat template is the HuggingFace repo that the original (non-GGUF) model came from. First, find this page. If you just have a model file, you can try a Google search for the model's name. If you know the page you downloaded the GGUF model from, its README usually links to the original non-GGUF model.
Once you have located the original model, there are two methods you can use to extract its chat template. Pick whichever one you are most comfortable with.
### Using the CLI (all models)
1. Install `jq` using your preferred package manager - e.g. Chocolatey (Windows), Homebrew (macOS), or apt (Ubuntu).
2. Download `tokenizer_config.json` from the model's "Files and versions" tab.
3. Open a command prompt in the directory where you downloaded `tokenizer_config.json`.
4. Run `jq -r ".chat_template" tokenizer_config.json`. This shows the chat template in a human-readable form. You can copy this and paste it into the settings page.
5. (Optional) You can save the output to a text file like this: `jq -r ".chat_template" tokenizer_config.json >chat_template.txt`
If the output is "null", the model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (open models)
1. Install `transformers` using your preferred Python package manager, e.g. `pip install transformers`. Make sure it is at least version 4.43.0.
2. Copy the ID of the HuggingFace model, using the clipboard icon next to the name. For example, if the URL is `https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B`, the ID is `NousResearch/Hermes-2-Pro-Llama-3-8B`.
3. Open a Python interpreter (`python`) and run the following commands. Change the model ID in the example to the one you copied.
```
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B')
>>> print(tokenizer.get_chat_template())
```
You can copy the output and paste it into the settings page.
4. (Optional) You can save the output to a text file like this:
```
>>> open('chat_template.txt', 'w').write(tokenizer.get_chat_template())
```
If you get a ValueError exception, this model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (gated models)
Some models, such as Llama and Mistral, do not allow public access to their chat template. You must either use the CLI method above, or follow these instructions to use Python:
1. For these steps, you must have git and git-lfs installed.
2. You must have a HuggingFace account and be logged in.
3. You must already have access to the gated model. Otherwise, request access.
4. You must have an SSH key configured for git access to HuggingFace.
5. `git clone` the model's HuggingFace repo using the SSH clone URL. There is no need to download the entire model, which is very large. A good way to do this on Linux is:
```console
$ GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:meta-llama/Llama-3.1-8B-Instruct.git
$ cd Llama-3.1-8B-Instruct
$ git lfs pull -I "tokenizer.*"
```
6. Follow the above instructions for open models, but replace the model ID with the path to the directory containing `tokenizer_config.json`:
```
>>> tokenizer = AutoTokenizer.from_pretrained('.')
```
## Advanced: How do chat templates work?
The chat template is applied to the entire conversation you see in the chat window. The template loops over the list of messages, each containing `role` and `content` fields. `role` is either `user`, `assistant`, or `system`.
GPT4All also supports the special variables `bos_token`, `eos_token`, and `add_generation_prompt`. See the [HuggingFace docs] for what those do.
[HuggingFace docs]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#special-variables
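As a concrete illustration, here is a minimal sketch of what that looks like, using Python's `jinja2` package and the same `ImmutableSandboxedEnvironment` that the Python bindings import. The ChatML-style template and the messages are invented for the example rather than taken from any particular model:

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

# A ChatML-style template, written here purely for illustration.
chatml_template = r"""
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}"""

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user",   "content": "What is a chat template?"},
]

env = ImmutableSandboxedEnvironment()
print(env.from_string(chatml_template).render(
    messages=messages,
    add_generation_prompt=True,  # open the assistant's turn at the end
    bos_token="", eos_token="",  # special variables some templates reference
))
```

The rendered output is the plain text that is actually fed to the model: each message wrapped in `<|im_start|>`/`<|im_end|>` markers, followed by an opening `<|im_start|>assistant` line for the model to complete.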
## Advanced: How do I make a chat template?
The best way to create a chat template is to start from an existing one as a reference, then modify it to use the format documented for the given model. The model's README page may explicitly give an example of its template, or it may mention the name of a well-known standard template such as ChatML, Alpaca, or Vicuna. GPT4All does not yet include presets for these templates, so they will have to be found in other models or taken from the community.
For more information, see the very helpful [HuggingFace guide]. Some of it does not apply, such as the information about tool calling and RAG, which GPT4All implements differently.
Some models use a prompt template that does not intuitively map to a multi-turn chat because it was designed for single instructions. The [FastChat] implementation of these templates is a useful reference for the correct way to extend them to multiple messages; a rough sketch follows below.
[HuggingFace guide]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#advanced-template-writing-tips
[FastChat]: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
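As mentioned above, here is a rough sketch of how a single-instruction, Alpaca-style format might be extended to a multi-turn template. The exact section headers and separators are assumptions for illustration; check the model's documentation (or the FastChat source) for the format it actually expects, and render it the same way as in the previous sketch:

```python
# Hypothetical Alpaca-style template extended to a multi-turn conversation.
# The '### Instruction:'/'### Response:' headers are illustrative assumptions.
alpaca_style_template = r"""
{%- if messages and messages[0]['role'] == 'system' %}
{{- messages[0]['content'] + '\n\n' }}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{{- '### Instruction:\n' + message['content'] + '\n\n' }}
{%- elif message['role'] == 'assistant' %}
{{- '### Response:\n' + message['content'] + '\n\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '### Response:\n' }}
{%- endif %}"""
```

The Jinja text between the triple quotes is what would be pasted into the chat template field in Model Settings.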
## Advanced: What are GPT4All v1 templates?
GPT4All supports its own template syntax, which is nonstandard but provides complete control over the way LocalDocs sources and file attachments are inserted into the conversation. These templates begin with `{# gpt4all v1 #}` and look similar to the example below.
For standard templates, GPT4All combines the user message, sources, and attachments into the `content` field. For GPT4All v1 templates, this is not done, so the template must reference `sources` and `prompt_attachments` directly for those features to work correctly.
```jinja
{# gpt4all v1 #}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{%- if message['role'] == 'user' %}
{%- for source in message['sources'] %}
{%- if loop.first %}
{{- '### Context:\n' }}
{%- endif %}
{{- 'Collection: ' + source['collection'] + '\n' +
'Path: ' + source['path'] + '\n' +
'Excerpt: ' + source['text'] + '\n\n' }}
{%- endfor %}
{%- endif %}
{%- for attachment in message['prompt_attachments'] %}
{{- attachment['processed_content'] + '\n\n' }}
{%- endfor %}
{{- message['content'] | trim }}
{{- '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
```

View File

@ -9,7 +9,7 @@ import textwrap
import threading
from enum import Enum
from queue import Queue
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@ -23,7 +23,9 @@ else:
from typing import TypedDict
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import ParamSpec, TypeAlias
T = TypeVar("T")
P = ParamSpec("P")
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
@ -31,7 +33,7 @@ cuda_found: bool = False
# TODO(jared): use operator.call after we drop python 3.10 support
def _operator_call(obj, /, *args, **kwargs):
def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
return obj(*args, **kwargs)
@ -116,16 +118,15 @@ llmodel = load_llmodel_library()
class LLModelPromptContext(ctypes.Structure):
_fields_ = [
("n_past", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float),
]
@ -157,23 +158,21 @@ llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)
llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
ctypes.c_bool,
ctypes.POINTER(LLModelPromptContext),
ctypes.c_bool,
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_char_p),
]
llmodel.llmodel_prompt.restype = None
llmodel.llmodel_prompt.restype = ctypes.c_bool
llmodel.llmodel_embed.argtypes = [
ctypes.c_void_p,
@ -222,6 +221,12 @@ llmodel.llmodel_model_backend_name.restype = ctypes.c_char_p
llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p
llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char_p)]
llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
llmodel.llmodel_model_foreach_special_token.restype = None
ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@ -266,7 +271,6 @@ class LLModel:
self.model_path = model_path.encode()
self.n_ctx = n_ctx
self.ngl = ngl
self.context: LLModelPromptContext | None = None
self.buffer = bytearray()
self.buff_expecting_cont_bytes: int = 0
@ -286,6 +290,10 @@ class LLModel:
raise RuntimeError(f"Unable to instantiate model: {errmsg}")
self.model: ctypes.c_void_p | None = model
self.special_tokens_map: dict[str, str] = {}
llmodel.llmodel_model_foreach_special_token(
self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
)
def __del__(self, llmodel=llmodel):
if hasattr(self, 'model'):
@ -312,6 +320,19 @@ class LLModel:
dev = llmodel.llmodel_model_gpu_device_name(self.model)
return None if dev is None else dev.decode()
def count_prompt_tokens(self, prompt: str) -> int:
if self.model is None:
self._raise_closed()
err = ctypes.c_char_p()
n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err))
if n_tok < 0:
s = err.value
errmsg = 'null' if s is None else s.decode()
raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
return n_tok
llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
@staticmethod
def list_gpus(mem_required: int = 0) -> list[str]:
"""
@ -375,48 +396,6 @@ class LLModel:
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
def _set_context(
self,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
):
if self.context is None:
context = LLModelPromptContext(
n_past=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
)
self.context = context
else:
context = self.context
if reset_context:
self.context.n_past = 0
self.context.n_predict = n_predict
self.context.top_k = top_k
self.context.top_p = top_p
self.context.min_p = min_p
self.context.temp = temp
self.context.n_batch = n_batch
self.context.repeat_penalty = repeat_penalty
self.context.repeat_last_n = repeat_last_n
self.context.context_erase = context_erase
@overload
def generate_embeddings(
self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
@ -486,20 +465,18 @@ class LLModel:
def prompt_model(
self,
prompt: str,
prompt_template: str,
callback: ResponseCallbackType,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
special: bool = False,
prompt : str,
callback : ResponseCallbackType,
n_predict : int = 4096,
top_k : int = 40,
top_p : float = 0.9,
min_p : float = 0.0,
temp : float = 0.1,
n_batch : int = 8,
repeat_penalty : float = 1.2,
repeat_last_n : int = 10,
context_erase : float = 0.75,
reset_context : bool = False,
):
"""
Generate a response from the model for a given prompt.
@ -522,34 +499,38 @@ class LLModel:
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
self._set_context(
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
reset_context=reset_context,
context = LLModelPromptContext(
n_predict = n_predict,
top_k = top_k,
top_p = top_p,
min_p = min_p,
temp = temp,
n_batch = n_batch,
repeat_penalty = repeat_penalty,
repeat_last_n = repeat_last_n,
context_erase = context_erase,
)
llmodel.llmodel_prompt(
error_msg: bytes | None = None
def error_callback(msg: bytes) -> None:
nonlocal error_msg
error_msg = msg
err = ctypes.c_char_p()
if not llmodel.llmodel_prompt(
self.model,
ctypes.c_char_p(prompt.encode()),
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)),
True,
self.context,
special,
ctypes.c_char_p(),
)
context,
ctypes.byref(err),
):
s = err.value
raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}")
def prompt_model_streaming(
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any,
) -> Iterator[str]:
if self.model is None:
self._raise_closed()
@ -568,15 +549,15 @@ class LLModel:
return _generator_callback
def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, prompt_template, callback, **kwargs)
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
output_queue.put(Sentinel.TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
thread = threading.Thread(
target=run_llmodel_prompt,
args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
args=(prompt, _generator_callback_wrapper(callback)),
kwargs=kwargs,
)
thread.start()
@ -631,5 +612,5 @@ class LLModel:
# Empty prompt callback
@staticmethod
def _prompt_callback(token_id: int) -> bool:
def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool:
return True
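
To illustrate the reshaped low-level API (a hedged sketch, not documented usage): prompt_model() now takes the fully rendered prompt and no longer accepts a prompt_template, while the prompt callback (left at its default here) reports batches of token ids plus a cached flag.

# Sketch: `m` is assumed to be a loaded LLModel; the prompt text is already rendered
# by a chat template, so no template argument is passed.
def on_token(token_id: int, piece: str) -> bool:
    print(piece, end="", flush=True)
    return True   # returning False cancels generation

m.prompt_model("<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n", on_token, n_predict=64)
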

@ -4,37 +4,66 @@ Python only API for running all GPT4All models.
from __future__ import annotations
import hashlib
import json
import os
import platform
import re
import sys
import warnings
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload
import jinja2
import requests
from jinja2.sandbox import ImmutableSandboxedEnvironment
from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError
from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
LLModel, ResponseCallbackType, empty_response_callback)
LLModel, ResponseCallbackType, _operator_call, empty_response_callback)
if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias
if sys.platform == 'darwin':
if sys.platform == "darwin":
import fcntl
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
ConfigType: TypeAlias = "dict[str, Any]"
ConfigType: TypeAlias = 'dict[str, Any]'
MessageType: TypeAlias = 'dict[str, str]'
# Environment setup adapted from HF transformers
@_operator_call
def _jinja_env() -> ImmutableSandboxedEnvironment:
def raise_exception(message: str) -> NoReturn:
raise jinja2.exceptions.TemplateError(message)
def tojson(obj: Any, indent: int | None = None) -> str:
return json.dumps(obj, ensure_ascii=False, indent=indent)
def strftime_now(fmt: str) -> str:
return datetime.now().strftime(fmt)
env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.filters["tojson" ] = tojson
env.globals["raise_exception"] = raise_exception
env.globals["strftime_now" ] = strftime_now
return env
class MessageType(TypedDict):
role: str
content: str
class ChatSession(NamedTuple):
template: jinja2.Template
history: list[MessageType]
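
As a rough illustration of how the sandboxed environment and the new session types fit together (the template string below is a simplified stand-in, not one of the templates shipped in models3.json):

# Illustrative only: render a short history with the sandboxed environment above.
demo_template = _jinja_env.from_string(
    "{%- for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{%- endfor %}"
    "{%- if add_generation_prompt %}assistant:{%- endif %}"
)
history: list[MessageType] = [
    MessageType(role="system", content="Be terse."),
    MessageType(role="user", content="Hello!"),
]
print(demo_template.render(messages=history, add_generation_prompt=True))
# -> "system: Be terse.\nuser: Hello!\nassistant:"
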
class Embed4All:
@ -54,7 +83,7 @@ class Embed4All:
kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor.
"""
if model_name is None:
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs)
def __enter__(self) -> Self:
@ -145,18 +174,18 @@ class Embed4All:
dimensionality = -1
else:
if dimensionality <= 0:
raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}')
raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}")
if dimensionality < self.MIN_DIMENSIONALITY:
warnings.warn(
f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.'
' Performance may be degraded.'
f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}."
" Performance may be degraded."
)
try:
do_mean = {"mean": True, "truncate": False}[long_text_mode]
except KeyError:
raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
return result if return_dict else result['embeddings']
return result if return_dict else result["embeddings"]
class GPT4All:
@ -204,8 +233,7 @@ class GPT4All:
"""
self.model_type = model_type
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"
self._chat_session: ChatSession | None = None
device_init = None
if sys.platform == "darwin":
@ -264,7 +292,13 @@ class GPT4All:
@property
def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history)
return None if self._chat_session is None else self._chat_session.history
@current_chat_session.setter
def current_chat_session(self, history: list[MessageType]) -> None:
if self._chat_session is None:
raise ValueError("current_chat_session may only be set when there is an active chat session")
self._chat_session.history[:] = history
@staticmethod
def list_models() -> list[ConfigType]:
@ -276,7 +310,7 @@ class GPT4All:
"""
resp = requests.get("https://gpt4all.io/models/models3.json")
if resp.status_code != 200:
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}")
return resp.json()
@classmethod
@ -306,15 +340,9 @@ class GPT4All:
# get the config for the model
config: ConfigType = {}
if allow_download:
available_models = cls.list_models()
for m in available_models:
if model_filename == m["filename"]:
tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
# change to Python-style formatting
m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
config.update(m)
break
models = cls.list_models()
if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None:
config.update(model)
# Validate download directory
if model_path is None:
@ -378,13 +406,13 @@ class GPT4All:
headers = {}
if offset:
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
headers['Range'] = f'bytes={offset}-' # resume incomplete response
headers["Range"] = f"bytes={offset}-" # resume incomplete response
headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges
response = requests.get(url, stream=True, headers=headers)
if response.status_code not in (200, 206):
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
raise ValueError('Connection was interrupted and server does not support range requests')
raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")):
raise ValueError("Connection was interrupted and server does not support range requests")
if (enc := response.headers.get("Content-Encoding")) is not None:
raise ValueError(f"Expected identity Content-Encoding, got {enc}")
return response
@ -483,19 +511,19 @@ class GPT4All:
def generate(
self,
prompt: str,
prompt : str,
*,
max_tokens: int = 200,
temp: float = 0.7,
top_k: int = 40,
top_p: float = 0.4,
min_p: float = 0.0,
repeat_penalty: float = 1.18,
repeat_last_n: int = 64,
n_batch: int = 8,
n_predict: int | None = None,
streaming: bool = False,
callback: ResponseCallbackType = empty_response_callback,
max_tokens : int = 200,
temp : float = 0.7,
top_k : int = 40,
top_p : float = 0.4,
min_p : float = 0.0,
repeat_penalty : float = 1.18,
repeat_last_n : int = 64,
n_batch : int = 8,
n_predict : int | None = None,
streaming : bool = False,
callback : ResponseCallbackType = empty_response_callback,
) -> Any:
"""
Generate outputs from any GPT4All model.
@ -520,122 +548,94 @@ class GPT4All:
# Preparing the model request
generate_kwargs: dict[str, Any] = dict(
temp=temp,
top_k=top_k,
top_p=top_p,
min_p=min_p,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
n_batch=n_batch,
n_predict=n_predict if n_predict is not None else max_tokens,
temp = temp,
top_k = top_k,
top_p = top_p,
min_p = min_p,
repeat_penalty = repeat_penalty,
repeat_last_n = repeat_last_n,
n_batch = n_batch,
n_predict = n_predict if n_predict is not None else max_tokens,
)
if self._history is not None:
# check if there is only one message, i.e. system prompt:
reset = len(self._history) == 1
self._history.append({"role": "user", "content": prompt})
fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined]
if fct_func is GPT4All._format_chat_prompt_template:
if reset:
# ingest system prompt
# use "%1%2" and not "%1" to avoid implicit whitespace
self.model.prompt_model(self._history[0]["content"], "%1%2",
empty_response_callback,
n_batch=n_batch, n_predict=0, reset_context=True, special=True)
prompt_template = self._current_prompt_template.format("%1", "%2")
else:
warnings.warn(
"_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
DeprecationWarning,
)
# special tokens won't be processed
prompt = self._format_chat_prompt_template(
self._history[-1:],
self._history[0]["content"] if reset else "",
)
prompt_template = "%1"
generate_kwargs["reset_context"] = reset
else:
prompt_template = "%1"
generate_kwargs["reset_context"] = True
# Prepare the callback, process the model response
output_collector: list[MessageType]
output_collector = [
{"content": ""}
] # placeholder for the self._history if chat session is not activated
full_response = ""
if self._history is not None:
self._history.append({"role": "assistant", "content": ""})
output_collector = self._history
def _callback_wrapper(token_id: int, response: str) -> bool:
nonlocal full_response
full_response += response
return callback(token_id, response)
def _callback_wrapper(
callback: ResponseCallbackType,
output_collector: list[MessageType],
) -> ResponseCallbackType:
def _callback(token_id: int, response: str) -> bool:
nonlocal callback, output_collector
last_msg_rendered = prompt
if self._chat_session is not None:
session = self._chat_session
def render(messages: list[MessageType]) -> str:
return session.template.render(
messages=messages,
add_generation_prompt=True,
**self.model.special_tokens_map,
)
session.history.append(MessageType(role="user", content=prompt))
prompt = render(session.history)
if len(session.history) > 1:
last_msg_rendered = render(session.history[-1:])
output_collector[-1]["content"] += response
return callback(token_id, response)
return _callback
# Check request length
last_msg_len = self.model.count_prompt_tokens(last_msg_rendered)
if last_msg_len > (limit := self.model.n_ctx - 4):
raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).")
# Send the request to the model
if streaming:
return self.model.prompt_model_streaming(
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
def stream() -> Iterator[str]:
yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
return stream()
self.model.prompt_model(
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
return output_collector[-1]["content"]
self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
return full_response
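
A short usage sketch of the streaming path above (assumes an instantiated GPT4All object named `gpt` with an active chat session):

# Tokens are yielded as they arrive; with a chat session active, the full response
# is appended to the session history once the stream is exhausted.
for piece in gpt.generate("Tell me a joke.", streaming=True, max_tokens=64):
    print(piece, end="", flush=True)
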
@contextmanager
def chat_session(
self,
system_prompt: str | None = None,
prompt_template: str | None = None,
system_message: str | Literal[False] | None = None,
chat_template: str | None = None,
):
"""
Context manager to hold an inference-optimized chat session with a GPT4All model.
Args:
system_prompt: An initial instruction for the model.
prompt_template: Template for the prompts with {0} being replaced by the user message.
system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
"""
if system_prompt is None:
system_prompt = self.config.get("systemPrompt", "")
if system_message is None:
system_message = self.config.get("systemMessage", False)
if prompt_template is None:
if (tmpl := self.config.get("promptTemplate")) is None:
warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
"is deprecated. Defaulting to Alpaca.", DeprecationWarning)
tmpl = DEFAULT_PROMPT_TEMPLATE
prompt_template = tmpl
if chat_template is None:
if "name" not in self.config:
raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
if "chatTemplate" not in self.config:
raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
"currently implemented. Please pass a template to chat_session() directly.")
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
if re.search(r"%1(?![0-9])", prompt_template):
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
"placeholder, please use '{0}' instead.")
self._history = [{"role": "system", "content": system_prompt}]
self._current_prompt_template = prompt_template
history = []
if system_message is not False:
history.append(MessageType(role="system", content=system_message))
self._chat_session = ChatSession(
template=_jinja_env.from_string(chat_template),
history=history,
)
try:
yield self
finally:
self._history = None
self._current_prompt_template = "{0}"
self._chat_session = None
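
For reference, a minimal sketch of the reworked session API (the model file name is only an example; both keyword arguments of chat_session() are optional):

gpt = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")
with gpt.chat_session(system_message="You are a terse assistant."):
    print(gpt.generate("What is the capital of France?", max_tokens=32))
    print(gpt.current_chat_session)   # list of {"role": ..., "content": ...} dicts
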
@staticmethod
def list_gpus() -> list[str]:
@ -647,43 +647,6 @@ class GPT4All:
"""
return LLModel.list_gpus()
def _format_chat_prompt_template(
self,
messages: list[MessageType],
default_prompt_header: str = "",
default_prompt_footer: str = "",
) -> str:
"""
Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
Warning:
This function was deprecated in version 2.3.0, and will be removed in a future release.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
Returns:
Formatted prompt.
"""
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
for message in messages:
if message["role"] == "user":
user_message = self._current_prompt_template.format(message["content"])
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = message["content"] + "\n"
full_prompt += assistant_message
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
return full_prompt
def append_extension_if_missing(model_name):
if not model_name.endswith((".bin", ".gguf")):
@ -696,7 +659,7 @@ class _HasFileno(Protocol):
def _fsync(fd: int | _HasFileno) -> None:
if sys.platform == 'darwin':
if sys.platform == "darwin":
# Apple's fsync does not flush the drive write cache
try:
fcntl.fcntl(fd, fcntl.F_FULLFSYNC)

@ -14,6 +14,7 @@ nav:
- 'Models' : 'gpt4all_desktop/models.md'
- 'LocalDocs' : 'gpt4all_desktop/localdocs.md'
- 'Settings' : 'gpt4all_desktop/settings.md'
- 'Chat Templates' : 'gpt4all_desktop/chat_templates.md'
- 'Cookbook':
- 'Local AI Chat with Microsoft Excel': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-microsoft-excel.md'
- 'Local AI Chat with your Google Drive': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-google-drive.md'

@ -88,9 +88,10 @@ setup(
python_requires='>=3.8',
packages=find_packages(),
install_requires=[
'importlib_resources; python_version < "3.9"',
'jinja2~=3.1',
'requests',
'tqdm',
'importlib_resources; python_version < "3.9"',
'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
],
extras_require={

@ -190,6 +190,7 @@ qt_add_executable(chat
src/database.cpp src/database.h
src/download.cpp src/download.h
src/embllm.cpp src/embllm.h
src/jinja_helpers.cpp src/jinja_helpers.h
src/llm.cpp src/llm.h
src/localdocs.cpp src/localdocs.h
src/localdocsmodel.cpp src/localdocsmodel.h
@ -215,6 +216,7 @@ qt_add_qml_module(chat
qml/ApplicationSettings.qml
qml/ChatDrawer.qml
qml/ChatItemView.qml
qml/ChatMessageButton.qml
qml/ChatView.qml
qml/CollectionsDrawer.qml
qml/HomeView.qml
@ -227,7 +229,7 @@ qt_add_qml_module(chat
qml/PopupDialog.qml
qml/SettingsView.qml
qml/StartupDialog.qml
qml/SwitchModelDialog.qml
qml/ConfirmationDialog.qml
qml/Theme.qml
qml/ThumbsDownDialog.qml
qml/Toast.qml
@ -386,7 +388,7 @@ target_include_directories(chat PRIVATE deps/usearch/include
target_link_libraries(chat
PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg)
target_link_libraries(chat
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx)
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx jinja2cpp)
if (APPLE)
target_link_libraries(chat PRIVATE ${COCOA_LIBRARY})

@ -11,3 +11,5 @@ add_subdirectory(DuckX)
set(QT_VERSION_MAJOR 6)
add_subdirectory(QXlsx/QXlsx)
add_subdirectory(Jinja2Cpp)

@ -0,0 +1 @@
Subproject commit b2a716798bfa63c7dae303fc1e272964c4e1f9ee

@ -1,3 +1 @@
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M28.4138 9.17125L22.8288 3.585C22.643 3.39924 22.4225 3.25188 22.1799 3.15134C21.9372 3.0508 21.6771 2.99905 21.4144 2.99905C21.1517 2.99905 20.8916 3.0508 20.6489 3.15134C20.4062 3.25188 20.1857 3.39924 20 3.585L4.58626 19C4.39973 19.185 4.25185 19.4053 4.15121 19.648C4.05057 19.8907 3.99917 20.151 4.00001 20.4138V26C4.00001 26.5304 4.21072 27.0391 4.5858 27.4142C4.96087 27.7893 5.46958 28 6.00001 28H11.5863C11.849 28.0008 12.1093 27.9494 12.352 27.8488C12.5947 27.7482 12.815 27.6003 13 27.4138L28.4138 12C28.5995 11.8143 28.7469 11.5938 28.8474 11.3511C28.948 11.1084 28.9997 10.8483 28.9997 10.5856C28.9997 10.3229 28.948 10.0628 28.8474 9.82015C28.7469 9.57747 28.5995 9.35698 28.4138 9.17125ZM6.41376 20L17 9.41375L19.0863 11.5L8.50001 22.085L6.41376 20ZM6.00001 22.4138L9.58626 26H6.00001V22.4138ZM12 25.5863L9.91376 23.5L20.5 12.9138L22.5863 15L12 25.5863ZM24 13.5863L18.4138 8L21.4138 5L27 10.585L24 13.5863Z" fill="black"/>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="#000000" viewBox="0 0 256 256"><path d="M227.31,73.37,182.63,28.68a16,16,0,0,0-22.63,0L36.69,152A15.86,15.86,0,0,0,32,163.31V208a16,16,0,0,0,16,16H92.69A15.86,15.86,0,0,0,104,219.31L227.31,96a16,16,0,0,0,0-22.63ZM92.69,208H48V163.31l88-88L180.69,120ZM192,108.68,147.31,64l24-24L216,84.68Z"></path></svg>

@ -29,7 +29,8 @@
"description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
"url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf",
"promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>"
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
},
{
"order": "c",
@ -45,7 +46,8 @@
"description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
"url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf",
"promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>"
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
},
{
"order": "d",
@ -77,7 +79,8 @@
"systemPrompt": "",
"description": "<strong>Strong overall fast instruction following model</strong><br><ul><li>Fast responses</li><li>Trained by Mistral AI<li>Uncensored</li><li>Licensed for commercial use</li></ul>",
"url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf",
"promptTemplate": "[INST] %1 [/INST]"
"promptTemplate": "[INST] %1 [/INST]",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_start = 1 %}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (message['role'] == 'user') != ((loop.index0 - loop_start) % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}"
},
{
"order": "f",
@ -125,7 +128,8 @@
"systemPrompt": "",
"description": "<strong>Very fast model with good quality</strong><br><ul><li>Fastest responses</li><li>Instruction based</li><li>Trained by TII<li>Finetuned by Nomic AI<li>Licensed for commercial use</ul>",
"url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf",
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User: ' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Assistant: ' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Assistant:' }}\n{%- endif %}"
},
{
"order": "i",
@ -140,7 +144,8 @@
"type": "LLaMA2",
"systemPrompt": "",
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
},
{
"order": "j",
@ -155,7 +160,8 @@
"type": "LLaMA2",
"systemPrompt": "",
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
},
{
"order": "k",
@ -170,7 +176,9 @@
"type": "LLaMA2",
"systemPrompt": "",
"description": "<strong>Strong overall larger model</strong><br><ul><li>Instruction based<li>Gives very long responses<li>Finetuned with only 1k of high-quality data<li>Trained by Microsoft and Peking University<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + ' ' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in loop_messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- if (loop.index0 - loop_start) % 2 == 0 %}\n {{- ' ' }}\n {%- else %}\n {{- eos_token }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- 'ASSISTANT:' }}\n{%- endif %}",
"systemMessage": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
},
{
"order": "l",
@ -186,7 +194,8 @@
"description": "<strong>Ghost 7B v0.9.1</strong> fast, powerful and smooth for Vietnamese and English languages.",
"url": "https://huggingface.co/lamhieu/ghost-7b-v0.9.1-gguf/resolve/main/ghost-7b-v0.9.1-Q4_0.gguf",
"promptTemplate": "<|user|>\n%1</s>\n<|assistant|>\n%2</s>\n",
"systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>"
"systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>",
"systemMessage": "You are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior."
},
{
"order": "m",
@ -202,7 +211,8 @@
"systemPrompt": "",
"description": "<strong>Extremely good model</strong><br><ul><li>Instruction based<li>Gives long responses<li>Curated with 300,000 uncensored instructions<li>Trained by Nous Research<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf",
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Instruction:\\n' }}\n{%- endif %}"
},
{
"order": "n",
@ -217,7 +227,9 @@
"type": "LLaMA",
"systemPrompt": "",
"description": "<strong>Very good overall model</strong><br><ul><li>Instruction based<li>Based on the same dataset as Groovy<li>Slower than Groovy, with higher quality responses<li>Trained by Nomic AI<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}",
"systemMessage": "Below is an instruction that describes a task. Write a response that appropriately completes the request."
},
{
"order": "o",
@ -234,7 +246,8 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
},
{
"order": "p",
@ -250,7 +263,8 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
},
{
"order": "q",
@ -282,7 +296,8 @@
"description": "<strong>Small version of new model with novel dataset</strong><br><ul><li>Very fast responses</li><li>Instruction based</li><li>Explain tuned datasets</li><li>Orca Research Paper dataset construction approaches</li><li>Cannot be used commercially</li></ul>",
"url": "https://gpt4all.io/models/gguf/orca-mini-3b-gguf2-q4_0.gguf",
"promptTemplate": "### User:\n%1\n\n### Response:\n",
"systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n"
"systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- '### System:\\n' + messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}"
},
{
"order": "s",
@ -299,7 +314,8 @@
"systemPrompt": "",
"promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf",
"chatTemplate": null
},
{
"order": "t",
@ -316,7 +332,8 @@
"systemPrompt": "",
"promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf",
"chatTemplate": null
},
{
"order": "u",
@ -333,7 +350,8 @@
"systemPrompt": "",
"promptTemplate": "%1",
"description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf"
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf",
"chatTemplate": null
},
{
"order": "v",
@ -351,7 +369,8 @@
"embeddingModel": true,
"systemPrompt": "",
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf",
"chatTemplate": null
},
{
"order": "w",
@ -367,7 +386,8 @@
"type": "Bert",
"embeddingModel": true,
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf"
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf",
"chatTemplate": null
},
{
"order": "x",
@ -383,7 +403,9 @@
"description": "<strong>Mistral-based model for German-language applications</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by ellamind<li>Finetuned on German instruction and chat data</a><li>Licensed for commercial use</ul>",
"url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf",
"promptTemplate": "USER: %1 ASSISTANT: ",
"systemPrompt": "Du bist ein hilfreicher Assistent. "
"systemPrompt": "Du bist ein hilfreicher Assistent. ",
"chatTemplate": "{%- set system_message = false %}\n{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {%- set system_message = true %}\n {{- messages[0]['content'] }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (not loop.first) or (system_message is not none) %}\n {{- ' ' }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('After the optional system message, conversation roles must be either user or assistant.') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {%- if messages %}\n {{- ' ' }}\n {%- endif %}\n {{- 'ASSISTANT:' }}\n{%- endif %}",
"systemMessage": "Du bist ein hilfreicher Assistent."
},
{
"order": "y",
@ -400,7 +422,8 @@
"embeddingModel": true,
"systemPrompt": "",
"description": "nomic-embed-text-v1",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf"
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf",
"chatTemplate": null
},
{
"order": "z",
@ -417,7 +440,8 @@
"embeddingModel": true,
"systemPrompt": "",
"description": "nomic-embed-text-v1.5",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf"
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf",
"chatTemplate": null
},
{
"order": "zzz",

@ -10,7 +10,7 @@ import network
import llm
MySettingsTab {
onRestoreDefaultsClicked: {
onRestoreDefaults: {
MySettings.restoreApplicationDefaults();
}
title: qsTr("Application")
@ -486,23 +486,6 @@ MySettingsTab {
Accessible.name: nThreadsLabel.text
Accessible.description: ToolTip.text
}
MySettingsLabel {
id: saveChatsContextLabel
text: qsTr("Save Chat Context")
helpText: qsTr("Save the chat model's state to disk for faster loading. WARNING: Uses ~2GB per chat.")
Layout.row: 12
Layout.column: 0
}
MyCheckBox {
id: saveChatsContextBox
Layout.row: 12
Layout.column: 2
Layout.alignment: Qt.AlignRight
checked: MySettings.saveChatsContext
onClicked: {
MySettings.saveChatsContext = !MySettings.saveChatsContext
}
}
MySettingsLabel {
id: trayLabel
text: qsTr("Enable System Tray")

@ -8,8 +8,23 @@ import QtQuick.Layouts
import gpt4all
import mysettings
ColumnLayout {
property var inputBoxText: null
signal setInputBoxText(text: string)
Item {
Layout.fillWidth: true
Layout.maximumWidth: parent.width
Layout.preferredHeight: gridLayout.height
HoverHandler { id: hoverArea }
GridLayout {
rows: 5
id: gridLayout
anchors.left: parent.left
anchors.right: parent.right
columns: 2
Item {
@ -40,7 +55,7 @@ GridLayout {
to: 360
duration: 1000
loops: Animation.Infinite
running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText)
running: isCurrentResponse && currentChat.responseInProgress
}
}
}
@ -73,13 +88,11 @@ GridLayout {
color: theme.mutedTextColor
}
RowLayout {
visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText)
visible: isCurrentResponse && (value === "" && currentChat.responseInProgress)
Text {
color: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarger
text: {
if (currentChat.restoringFromText)
return qsTr("restoring from text ...");
switch (currentChat.responseState) {
case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@ -99,10 +112,11 @@ GridLayout {
Layout.row: 1
Layout.column: 1
Layout.fillWidth: true
spacing: 20
spacing: 10
Flow {
id: attachedUrlsFlow
Layout.fillWidth: true
Layout.bottomMargin: 10
spacing: 10
visible: promptAttachments.length !== 0
Repeater {
@ -156,7 +170,7 @@ GridLayout {
focus: false
readOnly: true
font.pixelSize: theme.fontSizeLarge
cursorVisible: currentResponse ? currentChat.responseInProgress : false
cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false
cursorPosition: text.length
TapHandler {
id: tapHandler
@ -183,12 +197,12 @@ GridLayout {
}
onLinkActivated: function(link) {
if (!currentResponse || !currentChat.responseInProgress)
if (!isCurrentResponse || !currentChat.responseInProgress)
Qt.openUrlExternally(link)
}
onLinkHovered: function (link) {
if (!currentResponse || !currentChat.responseInProgress)
if (!isCurrentResponse || !currentChat.responseInProgress)
statusBar.externalHoveredLink = link
}
@ -239,13 +253,19 @@ GridLayout {
textProcessor.setValue(value);
}
property bool textProcessorReady: false
Component.onCompleted: {
resetChatViewTextProcessor();
chatModel.valueChanged.connect(function(i, value) {
if (index === i)
textProcessorReady = true;
}
Connections {
target: chatModel
function onValueChanged(i, value) {
if (myTextArea.textProcessorReady && index === i)
textProcessor.setValue(value);
}
);
}
Connections {
@ -282,67 +302,6 @@ GridLayout {
Network.sendConversation(currentChat.id, getConversationJson());
}
}
Column {
Layout.alignment: Qt.AlignRight
Layout.rightMargin: 15
visible: name === "Response: " &&
(!currentResponse || !currentChat.responseInProgress) && MySettings.networkIsActive
spacing: 10
Item {
width: childrenRect.width
height: childrenRect.height
MyToolButton {
id: thumbsUp
width: 24
height: 24
imageWidth: width
imageHeight: height
opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
source: "qrc:/gpt4all/icons/thumbs_up.svg"
Accessible.name: qsTr("Thumbs up")
Accessible.description: qsTr("Gives a thumbs up to the response")
onClicked: {
if (thumbsUpState && !thumbsDownState)
return
chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
MyToolButton {
id: thumbsDown
anchors.top: thumbsUp.top
anchors.topMargin: 3
anchors.left: thumbsUp.right
anchors.leftMargin: 3
width: 24
height: 24
imageWidth: width
imageHeight: height
checked: thumbsDownState
opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
transform: [
Matrix4x4 {
matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
},
Translate {
x: thumbsDown.width
}
]
source: "qrc:/gpt4all/icons/thumbs_down.svg"
Accessible.name: qsTr("Thumbs down")
Accessible.description: qsTr("Opens thumbs down dialog")
onClicked: {
thumbsDownDialog.open()
}
}
}
}
}
Item {
@ -353,11 +312,13 @@ GridLayout {
Layout.preferredWidth: childrenRect.width
Layout.preferredHeight: childrenRect.height
visible: {
if (name !== "Response: ")
return false
if (consolidatedSources.length === 0)
return false
if (!MySettings.localDocsShowReferences)
return false
if (currentResponse && currentChat.responseInProgress
if (isCurrentResponse && currentChat.responseInProgress
&& currentChat.responseState !== Chat.GeneratingQuestions )
return false
return true
@ -443,7 +404,7 @@ GridLayout {
return false
if (!MySettings.localDocsShowReferences)
return false
if (currentResponse && currentChat.responseInProgress
if (isCurrentResponse && currentChat.responseInProgress
&& currentChat.responseState !== Chat.GeneratingQuestions )
return false
return true
@ -566,8 +527,139 @@ GridLayout {
}
}
ConfirmationDialog {
id: editPromptDialog
dialogTitle: qsTr("Edit this prompt?")
description: qsTr("The existing response and all later messages will be permanently erased.")
onAccepted: {
const msg = currentChat.popPrompt(index);
if (msg !== null)
setInputBoxText(msg);
}
}
ConfirmationDialog {
id: redoResponseDialog
dialogTitle: qsTr("Redo this response?")
description: qsTr("The existing response and all later messages will be permanently erased.")
onAccepted: currentChat.regenerateResponse(index)
}
RowLayout {
id: buttonRow
Layout.row: 4
Layout.column: 1
Layout.maximumWidth: parent.width
Layout.fillWidth: false
Layout.alignment: Qt.AlignLeft | Qt.AlignTop
spacing: 3
visible: !isCurrentResponse || !currentChat.responseInProgress
enabled: opacity > 0
opacity: hoverArea.hovered
readonly property var canModify: !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
Behavior on opacity {
OpacityAnimator { duration: 30 }
}
ChatMessageButton {
visible: parent.canModify && model.name === "Prompt: "
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
source: "qrc:/gpt4all/icons/edit.svg"
onClicked: {
if (inputBoxText === "")
editPromptDialog.open();
}
name: qsTr("Edit")
}
ChatMessageButton {
visible: parent.canModify && model.name === "Response: "
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
name: qsTr("Redo")
source: "qrc:/gpt4all/icons/regenerate.svg"
onClicked: redoResponseDialog.open()
}
ChatMessageButton {
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
name: qsTr("Copy")
source: "qrc:/gpt4all/icons/copy.svg"
onClicked: {
myTextArea.selectAll();
myTextArea.copy();
myTextArea.deselect();
}
}
Item {
visible: name === "Response: " && MySettings.networkIsActive
Layout.alignment: Qt.AlignVCenter
Layout.preferredWidth: childrenRect.width
Layout.preferredHeight: childrenRect.height
Layout.fillWidth: false
ChatMessageButton {
id: thumbsUp
anchors.left: parent.left
anchors.verticalCenter: parent.verticalCenter
opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
source: "qrc:/gpt4all/icons/thumbs_up.svg"
name: qsTr("Like response")
onClicked: {
if (thumbsUpState && !thumbsDownState)
return
chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
ChatMessageButton {
id: thumbsDown
anchors.top: thumbsUp.top
anchors.topMargin: buttonRow.spacing
anchors.left: thumbsUp.right
anchors.leftMargin: buttonRow.spacing
checked: thumbsDownState
opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
bgTransform: [
Matrix4x4 {
matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
},
Translate {
x: thumbsDown.width
}
]
source: "qrc:/gpt4all/icons/thumbs_down.svg"
name: qsTr("Dislike response")
onClicked: {
thumbsDownDialog.open()
}
}
}
}
} // GridLayout
} // Item
GridLayout {
Layout.fillWidth: true
Layout.maximumWidth: parent.width
function shouldShowSuggestions() {
if (!currentResponse)
if (!isCurrentResponse)
return false;
if (MySettings.suggestionMode === 2) // Off
return false;
@ -577,8 +669,8 @@ GridLayout {
}
Item {
visible: shouldShowSuggestions()
Layout.row: 4
visible: parent.shouldShowSuggestions()
Layout.row: 5
Layout.column: 0
Layout.topMargin: 20
Layout.alignment: Qt.AlignVCenter | Qt.AlignRight
@ -601,8 +693,8 @@ GridLayout {
}
Item {
visible: shouldShowSuggestions()
Layout.row: 4
visible: parent.shouldShowSuggestions()
Layout.row: 5
Layout.column: 1
Layout.topMargin: 20
Layout.fillWidth: true
@ -627,8 +719,8 @@ GridLayout {
}
ColumnLayout {
visible: shouldShowSuggestions()
Layout.row: 5
visible: parent.shouldShowSuggestions()
Layout.row: 6
Layout.column: 1
Layout.fillWidth: true
Layout.minimumHeight: 1
@ -786,4 +878,7 @@ GridLayout {
}
}
}
}
} // GridLayout
} // ColumnLayout

@ -0,0 +1,20 @@
import QtQuick
import QtQuick.Controls
import gpt4all
MyToolButton {
property string name
width: 24
height: 24
imageWidth: width
imageHeight: height
ToolTip {
visible: parent.hovered
y: parent.height * 1.5
text: name
delay: Qt.styleHints.mousePressAndHoldInterval
}
Accessible.name: name
}

@ -24,6 +24,12 @@ Rectangle {
property var currentChat: ChatListModel.currentChat
property var chatModel: currentChat.chatModel
property var currentModelInfo: currentChat && currentChat.modelInfo
property var currentModelId: null
onCurrentModelInfoChanged: {
const newId = currentModelInfo && currentModelInfo.id;
if (currentModelId !== newId) { currentModelId = newId; }
}
signal addCollectionViewRequested()
signal addModelViewRequested()
@ -79,14 +85,11 @@ Rectangle {
function open_(msg) { message = msg; open(); }
}
SwitchModelDialog {
ConfirmationDialog {
id: switchModelDialog
anchors.centerIn: parent
Item {
Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Switch model dialog")
Accessible.description: qsTr("Warn the user if they switch models, then context will be erased")
}
property int index: -1
dialogTitle: qsTr("Erase conversation?")
description: qsTr("Changing the model will erase the current conversation.")
}
PopupDialog {
@ -103,6 +106,16 @@ Rectangle {
font.pixelSize: theme.fontSizeLarge
}
ConfirmationDialog {
id: resetContextDialog
dialogTitle: qsTr("Erase conversation?")
description: qsTr("The entire chat will be erased.")
onAccepted: {
Network.trackChatEvent("reset_context", { "length": chatModel.count });
currentChat.reset();
}
}
function getConversation() {
var conversation = "";
for (var i = 0; i < chatModel.count; i++) {
@ -703,7 +716,7 @@ Rectangle {
if (i !== -1) {
defaultModel = comboBox.valueAt(i);
} else {
defaultModel = comboBox.valueAt(0);
defaultModel = comboBox.count ? comboBox.valueAt(0) : "";
}
if (defaultModel !== "") {
defaultModelName = ModelList.modelInfo(defaultModel).name;
@ -790,9 +803,9 @@ Rectangle {
Layout.leftMargin: 50
Layout.rightMargin: 50
Layout.alignment: Qt.AlignHCenter
spacing: 25
spacing: 10
model: chatModel
cacheBuffer: Math.max(0, listView.contentHeight)
cacheBuffer: 2147483647
ScrollBar.vertical: ScrollBar {
policy: ScrollBar.AsNeeded
@ -804,6 +817,12 @@ Rectangle {
delegate: ChatItemView {
width: listView.contentItem.width - 15
inputBoxText: textInput.text
onSetInputBoxText: text => {
textInput.text = text;
textInput.forceActiveFocus();
textInput.cursorPosition = text.length;
}
}
function scrollToEnd() {
@ -832,11 +851,9 @@ Rectangle {
clip: true
z: 400
property bool isHovered: {
return conversationTrayButton.isHovered ||
resetContextButton.hovered || copyChatButton.hovered ||
regenerateButton.hovered
}
property bool isHovered: (
conversationTrayButton.isHovered || resetContextButton.hovered || copyChatButton.hovered
)
state: conversationTrayContent.isHovered ? "expanded" : "collapsed"
states: [
@ -892,11 +909,7 @@ Rectangle {
source: "qrc:/gpt4all/icons/recycle.svg"
imageWidth: 20
imageHeight: 20
onClicked: {
Network.trackChatEvent("reset_context", { "length": chatModel.count })
currentChat.reset();
currentChat.processSystemPrompt();
}
onClicked: resetContextDialog.open()
ToolTip.visible: resetContextButton.hovered
ToolTip.text: qsTr("Erase and reset chat session")
}
@ -921,34 +934,6 @@ Rectangle {
ToolTip.visible: copyChatButton.hovered
ToolTip.text: qsTr("Copy chat session to clipboard")
}
MyToolButton {
id: regenerateButton
Layout.preferredWidth: 40
Layout.preferredHeight: 40
source: "qrc:/gpt4all/icons/regenerate.svg"
imageWidth: 20
imageHeight: 20
visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
onClicked: {
if (chatModel.count < 2)
return
var promptIndex = chatModel.count - 2
var promptElement = chatModel.get(promptIndex)
var responseIndex = chatModel.count - 1
var responseElement = chatModel.get(responseIndex)
if (promptElement.name !== "Prompt: " || responseElement.name !== "Response: ")
return
currentChat.regenerateResponse()
chatModel.updateCurrentResponse(responseIndex, true)
chatModel.updateStopped(responseIndex, false)
chatModel.updateThumbsUpState(responseIndex, false)
chatModel.updateThumbsDownState(responseIndex, false)
chatModel.updateNewResponse(responseIndex, "")
currentChat.prompt(promptElement.promptPlusAttachments)
}
ToolTip.visible: regenerateButton.hovered
ToolTip.text: qsTr("Redo last chat response")
}
}
}
@ -1026,13 +1011,15 @@ Rectangle {
anchors.leftMargin: 30
horizontalAlignment: Qt.AlignRight
verticalAlignment: Qt.AlignVCenter
color: theme.mutedTextColor
visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== ""
color: textInputView.error !== null ? theme.textErrorColor : theme.mutedTextColor
visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" || textInputView.error !== null
elide: Text.ElideRight
wrapMode: Text.WordWrap
text: {
if (externalHoveredLink !== "")
return externalHoveredLink
if (textInputView.error !== null)
return textInputView.error;
const segments = [currentChat.tokenSpeed];
const device = currentChat.device;
@ -1050,6 +1037,7 @@ Rectangle {
}
font.pixelSize: theme.fontSizeSmaller
font.bold: true
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
}
RectangularGlow {
@ -1079,8 +1067,8 @@ Rectangle {
Rectangle {
id: textInputView
color: theme.controlBackground
border.width: 1
border.color: theme.controlBorder
border.width: error === null ? 1 : 2
border.color: error === null ? theme.controlBorder : theme.textErrorColor
radius: 10
anchors.left: parent.left
anchors.right: parent.right
@ -1091,6 +1079,41 @@ Rectangle {
height: textInputViewLayout.implicitHeight
visible: !currentChat.isServer && ModelList.selectableModels.count !== 0
property var error: null
function checkError() {
const info = currentModelInfo;
if (info === null || !info.id) {
error = null;
} else if (info.chatTemplate.isLegacy) {
error = qsTr("Legacy prompt template needs to be " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
"</a> in Settings.");
} else if (!info.chatTemplate.isSet) {
error = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> configured.");
} else if (/^\s*$/.test(info.chatTemplate.value)) {
error = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> cannot be blank.");
} else if (info.systemMessage.isLegacy) {
error = qsTr("Legacy system prompt needs to be " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
"</a> in Settings.");
} else
error = null;
}
Component.onCompleted: checkError()
Connections {
target: window
function onCurrentModelIdChanged() { textInputView.checkError(); }
}
Connections {
target: MySettings
function onChatTemplateChanged(info)
{ if (info.id === window.currentModelId) textInputView.checkError(); }
function onSystemMessageChanged(info)
{ if (info.id === window.currentModelId) textInputView.checkError(); }
}
MouseArea {
id: textInputViewMouseArea
anchors.fill: parent
@ -1214,16 +1237,16 @@ Rectangle {
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
Accessible.description: qsTr("Send messages/prompts to the model")
Keys.onReturnPressed: (event)=> {
if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier)
event.accepted = false;
else {
editingFinished();
sendMessage()
}
}
Keys.onReturnPressed: event => {
if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) {
event.accepted = false;
} else if (!chatModel.hasError && textInputView.error === null) {
editingFinished();
sendMessage();
}
}
function sendMessage() {
if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress || currentChat.restoringFromText)
if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress)
return
currentChat.stopGenerating()
@ -1338,6 +1361,7 @@ Rectangle {
imageWidth: theme.fontSizeLargest
imageHeight: theme.fontSizeLargest
visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0
enabled: !chatModel.hasError && textInputView.error === null
source: "qrc:/gpt4all/icons/send_message.svg"
Accessible.name: qsTr("Send message")
Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model")

View File

@ -0,0 +1,59 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
MyDialog {
id: confirmationDialog
anchors.centerIn: parent
modal: true
padding: 20
property alias dialogTitle: titleText.text
property alias description: descriptionText.text
Theme { id: theme }
contentItem: ColumnLayout {
Text {
id: titleText
Layout.alignment: Qt.AlignHCenter
textFormat: Text.StyledText
color: theme.textColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
}
Text {
id: descriptionText
Layout.alignment: Qt.AlignHCenter
textFormat: Text.StyledText
color: theme.textColor
font.pixelSize: theme.fontSizeMedium
}
}
footer: DialogButtonBox {
id: dialogBox
padding: 20
alignment: Qt.AlignRight
spacing: 10
MySettingsButton {
text: qsTr("OK")
textColor: theme.mediumButtonText
backgroundColor: theme.mediumButtonBackground
backgroundColorHovered: theme.mediumButtonBackgroundHovered
DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
}
MySettingsButton {
text: qsTr("Cancel")
DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
}
background: Rectangle {
color: "transparent"
}
Keys.onEnterPressed: confirmationDialog.accept()
Keys.onReturnPressed: confirmationDialog.accept()
}
Component.onCompleted: dialogBox.forceActiveFocus()
}

View File

@ -10,7 +10,7 @@ import mysettings
import network
MySettingsTab {
onRestoreDefaultsClicked: {
onRestoreDefaults: {
MySettings.restoreLocalDocsDefaults();
}

View File

@ -8,10 +8,34 @@ import mysettings
import chatlistmodel
MySettingsTab {
onRestoreDefaultsClicked: {
onRestoreDefaults: {
MySettings.restoreModelDefaults(root.currentModelInfo);
}
title: qsTr("Model")
ConfirmationDialog {
id: resetSystemMessageDialog
property var index: null
property bool resetClears: false
dialogTitle: qsTr("%1 system message?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
description: qsTr("The system message will be %1.").arg(resetClears ? qsTr("removed") : qsTr("reset to the default"))
onAccepted: MySettings.resetModelSystemMessage(ModelList.modelInfo(index))
function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
}
ConfirmationDialog {
id: resetChatTemplateDialog
property bool resetClears: false
property var index: null
dialogTitle: qsTr("%1 chat template?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
description: qsTr("The chat template will be %1.").arg(resetClears ? qsTr("erased") : qsTr("reset to the default"))
onAccepted: {
MySettings.resetModelChatTemplate(ModelList.modelInfo(index));
templateTextArea.resetText();
}
function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
}
contentItem: GridLayout {
id: root
columns: 3
@ -35,6 +59,7 @@ MySettingsTab {
RowLayout {
Layout.fillWidth: true
Layout.maximumWidth: parent.width
Layout.row: 2
Layout.column: 0
Layout.columnSpan: 2
@ -153,69 +178,154 @@ MySettingsTab {
Layout.fillWidth: true
}
MySettingsLabel {
visible: !root.currentModelInfo.isOnline
text: qsTr("System Prompt")
helpText: qsTr("Prefixed at the beginning of every conversation. Must contain the appropriate framing tokens.")
RowLayout {
Layout.row: 7
Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15
Layout.fillWidth: true
Layout.maximumWidth: parent.width
spacing: 10
MySettingsLabel {
id: systemMessageLabel
text: qsTr("System Message")
helpText: qsTr("A message to set the context or guide the behavior of the model. Leave blank for " +
"none. NOTE: Since GPT4All 3.5, this should not contain control tokens.")
onReset: () => resetSystemMessageDialog.show(root.currentModelId, resetClears)
function updateResetButton() {
const info = root.currentModelInfo;
// NOTE: checks if the *override* is set, regardless of whether there is a default
canReset = !!info.id && MySettings.isModelSystemMessageSet(info);
resetClears = !info.defaultSystemMessage;
}
Component.onCompleted: updateResetButton()
Connections {
target: root
function onCurrentModelIdChanged() { systemMessageLabel.updateResetButton(); }
}
Connections {
target: MySettings
function onSystemMessageChanged(info)
{ if (info.id === root.currentModelId) systemMessageLabel.updateResetButton(); }
}
}
Label {
id: systemMessageLabelHelp
visible: systemMessageArea.errState !== "ok"
Layout.alignment: Qt.AlignBottom
Layout.fillWidth: true
Layout.rightMargin: 5
Layout.maximumHeight: systemMessageLabel.height
text: qsTr("System message is not " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">plain text</a>.")
color: systemMessageArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
wrapMode: Text.Wrap
elide: Text.ElideRight
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
}
}
Rectangle {
id: systemPrompt
visible: !root.currentModelInfo.isOnline
id: systemMessage
Layout.row: 8
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
color: "transparent"
Layout.minimumHeight: Math.max(100, systemPromptArea.contentHeight + 20)
Layout.minimumHeight: Math.max(100, systemMessageArea.contentHeight + 20)
MyTextArea {
id: systemPromptArea
id: systemMessageArea
anchors.fill: parent
text: root.currentModelInfo.systemPrompt
property bool isBeingReset: false
function resetText() {
const info = root.currentModelInfo;
isBeingReset = true;
text = (info.id ? info.systemMessage.value : null) ?? "";
isBeingReset = false;
}
Component.onCompleted: resetText()
Connections {
target: MySettings
function onSystemPromptChanged() {
systemPromptArea.text = root.currentModelInfo.systemPrompt;
}
function onSystemMessageChanged(info)
{ if (info.id === root.currentModelId) systemMessageArea.resetText(); }
}
Connections {
target: root
function onCurrentModelInfoChanged() {
systemPromptArea.text = root.currentModelInfo.systemPrompt;
}
function onCurrentModelIdChanged() { systemMessageArea.resetText(); }
}
// strict validation, because setModelSystemMessage clears isLegacy
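// The pattern below flags framing tokens from legacy system prompts: "### System", "System:"/"SYSTEM:",
// ChatML <|im_start|>/<|im_end|>, Llama-3 <|start_header_id|>/<|end_header_id|>/<|eot_id|>,
// <|SYSTEM_TOKEN|>, and Llama-2 <<SYS>>.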
readonly property var reLegacyCheck: (
/(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>/m
)
onTextChanged: {
MySettings.setModelSystemPrompt(root.currentModelInfo, text)
const info = root.currentModelInfo;
if (!info.id) {
errState = "ok";
} else if (info.systemMessage.isLegacy && (isBeingReset || reLegacyCheck.test(text))) {
errState = "error";
} else
errState = reLegacyCheck.test(text) ? "warning" : "ok";
if (info.id && errState !== "error" && !isBeingReset)
MySettings.setModelSystemMessage(info, text);
systemMessageLabel.updateResetButton();
}
Accessible.role: Accessible.EditableText
Accessible.name: systemMessageLabel.text
Accessible.description: systemMessageLabelHelp.text
}
}
RowLayout {
Layout.row: 9
Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15
Layout.fillWidth: true
Layout.maximumWidth: parent.width
spacing: 10
MySettingsLabel {
id: promptTemplateLabel
text: qsTr("Prompt Template")
helpText: qsTr("The template that wraps every prompt.")
id: chatTemplateLabel
text: qsTr("Chat Template")
helpText: qsTr("This Jinja template turns the chat into input for the model.")
onReset: () => resetChatTemplateDialog.show(root.currentModelId, resetClears)
function updateResetButton() {
const info = root.currentModelInfo;
canReset = !!info.id && (
MySettings.isModelChatTemplateSet(info)
|| templateTextArea.text !== (info.chatTemplate.value ?? "")
);
resetClears = !info.defaultChatTemplate;
}
Component.onCompleted: updateResetButton()
Connections {
target: root
function onCurrentModelIdChanged() { chatTemplateLabel.updateResetButton(); }
}
Connections {
target: MySettings
function onChatTemplateChanged(info)
{ if (info.id === root.currentModelId) chatTemplateLabel.updateResetButton(); }
}
}
MySettingsLabel {
id: promptTemplateLabelHelp
text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.")
color: theme.textErrorColor
visible: templateTextArea.text.indexOf("%1") === -1
wrapMode: TextArea.Wrap
Label {
id: chatTemplateLabelHelp
visible: templateTextArea.errState !== "ok"
Layout.alignment: Qt.AlignBottom
Layout.fillWidth: true
Layout.rightMargin: 5
Layout.maximumHeight: chatTemplateLabel.height
text: templateTextArea.errMsg
color: templateTextArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
wrapMode: Text.Wrap
elide: Text.ElideRight
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
}
}
Rectangle {
id: promptTemplate
id: chatTemplate
Layout.row: 10
Layout.column: 0
Layout.columnSpan: 2
@ -226,27 +336,71 @@ MySettingsTab {
MyTextArea {
id: templateTextArea
anchors.fill: parent
text: root.currentModelInfo.promptTemplate
font: fixedFont
property bool isBeingReset: false
property var errMsg: null
function resetText() {
const info = root.currentModelInfo;
isBeingReset = true;
text = (info.id ? info.chatTemplate.value : null) ?? "";
isBeingReset = false;
}
Component.onCompleted: resetText()
Connections {
target: MySettings
function onPromptTemplateChanged() {
templateTextArea.text = root.currentModelInfo.promptTemplate;
}
function onChatTemplateChanged() { templateTextArea.resetText(); }
}
Connections {
target: root
function onCurrentModelInfoChanged() {
templateTextArea.text = root.currentModelInfo.promptTemplate;
}
function onCurrentModelIdChanged() { templateTextArea.resetText(); }
}
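// Heuristic for detecting a pre-Jinja template: old %1/%2 placeholders, no
// {% ... %} ... {{ ... }} ... {% ... %} block-and-expression shape anywhere in the text,
// or no mention of "content".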
function legacyCheck() {
return /%[12]\b/.test(text) || !/\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}/.test(text.replace(/\n/g, ''))
|| !/\bcontent\b/.test(text);
}
onTextChanged: {
if (templateTextArea.text.indexOf("%1") !== -1) {
MySettings.setModelPromptTemplate(root.currentModelInfo, text)
const info = root.currentModelInfo;
let jinjaError;
if (!info.id) {
errMsg = null;
errState = "ok";
} else if (info.chatTemplate.isLegacy && (isBeingReset || legacyCheck())) {
errMsg = null;
errState = "error";
} else if (text === "" && !info.chatTemplate.isSet) {
errMsg = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> configured.");
errState = "error";
} else if (/^\s*$/.test(text)) {
errMsg = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> cannot be blank.");
errState = "error";
} else if ((jinjaError = MySettings.checkJinjaTemplateError(text)) !== null) {
errMsg = qsTr("<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">Syntax" +
" error</a>: %1").arg(jinjaError);
errState = "error";
} else if (legacyCheck()) {
errMsg = qsTr("Chat template is not in " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"Jinja format</a>.")
errState = "warning";
} else {
errState = "ok";
}
if (info.id && errState !== "error" && !isBeingReset)
MySettings.setModelChatTemplate(info, text);
chatTemplateLabel.updateResetButton();
}
Keys.onPressed: event => {
if (event.key === Qt.Key_Tab) {
const a = templateTextArea;
event.accepted = true; // suppress tab
a.insert(a.cursorPosition, '    '); // four spaces
}
}
Accessible.role: Accessible.EditableText
Accessible.name: promptTemplateLabel.text
Accessible.description: promptTemplateLabelHelp.text
Accessible.name: chatTemplateLabel.text
Accessible.description: chatTemplateLabelHelp.text
}
}

View File

@ -17,6 +17,7 @@ Button {
property color borderColor: "transparent"
property real fontPixelSize: theme.fontSizeLarge
property string toolTip
property alias backgroundRadius: background.radius
contentItem: Text {
text: myButton.text
@ -28,6 +29,7 @@ Button {
Accessible.name: text
}
background: Rectangle {
id: background
radius: 10
border.width: borderWidth
border.color: borderColor

View File

@ -17,13 +17,42 @@ ColumnLayout {
property alias color: mainTextLabel.color
property alias linkColor: mainTextLabel.linkColor
Label {
id: mainTextLabel
color: theme.settingsTitleTextColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
onLinkActivated: function(link) {
root.linkActivated(link);
property var onReset: null
property alias canReset: resetButton.enabled
property bool resetClears: false
Item {
anchors.margins: 5
width: childrenRect.width
height: mainTextLabel.contentHeight
Label {
id: mainTextLabel
anchors.left: parent.left
anchors.top: parent.top
anchors.bottom: parent.bottom
color: theme.settingsTitleTextColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
verticalAlignment: Text.AlignVCenter
onLinkActivated: function(link) {
root.linkActivated(link);
}
}
MySettingsButton {
id: resetButton
anchors.baseline: mainTextLabel.baseline
anchors.left: mainTextLabel.right
height: mainTextLabel.contentHeight
anchors.leftMargin: 10
padding: 2
leftPadding: 10
rightPadding: 10
backgroundRadius: 5
text: resetClears ? qsTr("Clear") : qsTr("Reset")
visible: root.onReset !== null
onClicked: root.onReset()
}
}
Label {

View File

@ -9,7 +9,7 @@ Item {
property string title: ""
property Item contentItem: null
property bool showRestoreDefaultsButton: true
signal restoreDefaultsClicked
signal restoreDefaults
onContentItemChanged: function() {
if (contentItem) {
@ -19,6 +19,13 @@ Item {
}
}
ConfirmationDialog {
id: restoreDefaultsDialog
dialogTitle: qsTr("Restore defaults?")
description: qsTr("This page of settings will be reset to the defaults.")
onAccepted: root.restoreDefaults()
}
ScrollView {
id: scrollView
width: parent.width
@ -47,6 +54,7 @@ Item {
Column {
id: contentInner
Layout.fillWidth: true
Layout.maximumWidth: parent.width
}
Item {
@ -63,9 +71,7 @@ Item {
Accessible.role: Accessible.Button
Accessible.name: text
Accessible.description: qsTr("Restores settings dialog to a default state")
onClicked: {
root.restoreDefaultsClicked();
}
onClicked: restoreDefaultsDialog.open()
}
}
}

View File

@ -5,18 +5,27 @@ import QtQuick.Controls.Basic
TextArea {
id: myTextArea
property string errState: "ok" // one of "ok", "error", "warning"
color: enabled ? theme.textColor : theme.mutedTextColor
placeholderTextColor: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarge
background: Rectangle {
implicitWidth: 150
color: theme.controlBackground
border.width: 1
border.color: theme.controlBorder
border.width: errState === "ok" ? 1 : 2
border.color: {
switch (errState) {
case "ok": return theme.controlBorder;
case "warning": return theme.textWarningColor;
case "error": return theme.textErrorColor;
}
}
radius: 10
}
padding: 10
wrapMode: TextArea.Wrap
ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval
}
}

View File

@ -16,6 +16,7 @@ Button {
property alias fillMode: image.fillMode
property alias imageWidth: image.sourceSize.width
property alias imageHeight: image.sourceSize.height
property alias bgTransform: background.transform
contentItem: Text {
text: myButton.text
horizontalAlignment: Text.AlignHCenter
@ -26,6 +27,7 @@ Button {
}
background: Item {
id: background
anchors.fill: parent
Rectangle {
anchors.fill: parent

View File

@ -1,46 +0,0 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
import llm
import mysettings
MyDialog {
id: switchModelDialog
anchors.centerIn: parent
modal: true
padding: 20
property int index: -1
Theme {
id: theme
}
contentItem: Text {
textFormat: Text.StyledText
text: qsTr("<b>Warning:</b> changing the model will erase the current conversation. Do you wish to continue?")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
}
footer: DialogButtonBox {
id: dialogBox
padding: 20
alignment: Qt.AlignRight
spacing: 10
MySettingsButton {
text: qsTr("Continue")
Accessible.description: qsTr("Continue with model loading")
DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
}
MySettingsButton {
text: qsTr("Cancel")
Accessible.description: qsTr("Cancel")
DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
}
background: Rectangle {
color: "transparent"
}
}
}

View File

@ -64,6 +64,9 @@ QtObject {
property color green800: Qt.hsla(123/360, 0.17, 0.24)
property color green900: Qt.hsla(124/360, 0.17, 0.20)
property color green950: Qt.hsla(125/360, 0.22, 0.10)
property color green300_sat: Qt.hsla(122/360, 0.24, 0.73)
property color green400_sat: Qt.hsla(122/360, 0.23, 0.58)
property color green450_sat: Qt.hsla(122/360, 0.23, 0.52)
// yellow
property color yellow0: Qt.hsla(47/360, 0.90, 0.99)
@ -99,6 +102,7 @@ QtObject {
property color purple200: Qt.hsla(279/360, 1.0, 0.91)
property color purple300: Qt.hsla(279/360, 1.0, 0.84)
property color purple400: Qt.hsla(279/360, 1.0, 0.73)
property color purple450: Qt.hsla(279/360, 1.0, 0.68)
property color purple500: Qt.hsla(279/360, 1.0, 0.63)
property color purple600: Qt.hsla(279/360, 1.0, 0.53)
property color purple700: Qt.hsla(279/360, 1.0, 0.47)
@ -408,6 +412,39 @@ QtObject {
}
}
property color mediumButtonBackground: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return purple400
case MySettingsEnums.ChatTheme.Dark:
return green400_sat
default:
return green400_sat
}
}
property color mediumButtonBackgroundHovered: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return purple450
case MySettingsEnums.ChatTheme.Dark:
return green450_sat
default:
return green300_sat
}
}
property color mediumButtonText: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return textColor
case MySettingsEnums.ChatTheme.Dark:
return textColor
default:
return white
}
}
property color darkButtonText: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
@ -922,16 +959,8 @@ QtObject {
}
}
property color textErrorColor: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return red400
case MySettingsEnums.ChatTheme.Dark:
return red400
default:
return red400
}
}
readonly property color textErrorColor: red400
readonly property color textWarningColor: yellow400
property color settingsTitleTextColor: {
switch (MySettings.chatTheme) {

View File

@ -1,7 +1,6 @@
#include "chat.h"
#include "chatlistmodel.h"
#include "mysettings.h"
#include "network.h"
#include "server.h"
@ -11,7 +10,6 @@
#include <QLatin1String>
#include <QMap>
#include <QString>
#include <QStringList>
#include <QVariant>
#include <Qt>
#include <QtLogging>
@ -56,18 +54,18 @@ void Chat::connectLLM()
// Should be in different threads
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
@ -75,11 +73,10 @@ void Chat::connectLLM()
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
connect(ModelList::globalInstance(), &ModelList::modelInfoChanged, this, &Chat::handleModelInfoChanged);
}
void Chat::reset()
@ -87,28 +84,17 @@ void Chat::reset()
stopGenerating();
// Erase our current on disk representation as we're completely resetting the chat along with id
ChatListModel::globalInstance()->removeChatFile(this);
emit resetContextRequested();
m_id = Network::globalInstance()->generateUniqueId();
emit idChanged(m_id);
// NOTE: We deliberately do not reset the name or creation date to indicate that this was originally
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
// further down in the list. This might surprise the user. In the future, we might get rid of
// the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
// we effectively do a reset context. We *have* to do this right now when switching between different
// types of models. The only way to get rid of that would be a very long recalculate where we rebuild
// the context if we switch between different types of models. Probably the right way to fix this
// is to allow switching models but throwing up a dialog warning users if we switch between types
// of models that a long recalculation will ensue.
// the "reset context" button in the UI.
m_chatModel->clear();
m_needsSave = true;
}
void Chat::processSystemPrompt()
{
emit processSystemPromptRequested();
}
void Chat::resetResponseState()
{
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@ -160,25 +146,30 @@ void Chat::newPromptResponsePair(const QString &prompt, const QList<QUrl> &attac
if (!attachedContexts.isEmpty())
promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt;
newPromptResponsePairInternal(prompt, attachments);
emit resetResponseRequested();
resetResponseState();
qsizetype prevMsgIndex = m_chatModel->count() - 1;
if (prevMsgIndex >= 0)
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
m_chatModel->appendPrompt(prompt, attachments);
m_chatModel->appendResponse(prevMsgIndex + 1);
this->prompt(promptPlusAttached);
emit promptRequested(m_collections);
m_needsSave = true;
}
void Chat::prompt(const QString &prompt)
void Chat::regenerateResponse(int index)
{
resetResponseState();
emit promptRequested(m_collections, prompt);
emit regenerateResponseRequested(index);
m_needsSave = true;
}
void Chat::regenerateResponse()
QVariant Chat::popPrompt(int index)
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, QList<ResultInfo>());
emit regenerateResponseRequested();
auto content = m_llmodel->popPrompt(index);
m_needsSave = true;
if (content) return *content;
return QVariant::fromValue(nullptr);
}
void Chat::stopGenerating()
@ -202,6 +193,14 @@ void Chat::handleResponseChanged(const QString &response)
m_chatModel->updateValue(index, response);
}
void Chat::handleResponseFailed(const QString &error)
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, error);
m_chatModel->setError();
responseStopped(0);
}
void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
{
if (m_shouldDeleteLater)
@ -272,25 +271,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
emit modelChangeRequested(modelInfo);
}
// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
void Chat::serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments)
{
newPromptResponsePairInternal(prompt, attachments);
}
void Chat::newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments)
{
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt("Prompt: ", prompt, attachments);
m_chatModel->appendResponse("Response: ");
}
bool Chat::restoringFromText() const
{
return m_llmodel->restoringFromText();
}
void Chat::unloadAndDeleteLater()
{
if (!isModelLoaded()) {
@ -356,12 +336,6 @@ void Chat::generatedQuestionFinished(const QString &question)
m_needsSave = true;
}
void Chat::handleRestoringFromText()
{
Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
emit restoringFromTextChanged();
}
void Chat::handleModelLoadingError(const QString &error)
{
if (!error.isEmpty()) {
@ -396,12 +370,19 @@ QString Chat::fallbackReason() const
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{
m_databaseResults = results;
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, m_databaseResults);
m_needsSave = true;
}
// we need to notify listeners of the modelInfo property when its properties are updated,
// since it's a gadget and can't do that on its own
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
{
if (!m_modelInfo.id().isNull() && modelInfo.id() == m_modelInfo.id())
emit modelInfoChanged();
}
// react if a new model is loaded
void Chat::handleModelChanged(const ModelInfo &modelInfo)
{
if (m_modelInfo == modelInfo)
return;
@ -430,10 +411,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
if (version >= 3)
stream << m_collections;
const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
if (version >= 6)
stream << serializeKV;
if (!m_llmodel->serialize(stream, version, serializeKV))
if (!m_llmodel->serialize(stream, version))
return false;
if (!m_chatModel->serialize(stream, version))
return false;
@ -462,19 +440,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged();
bool discardKV = m_modelInfo.id().isEmpty();
if (version >= 3) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
bool deserializeKV = true;
if (version >= 6)
stream >> deserializeKV;
m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
if (!m_llmodel->deserialize(stream, version))
return false;
if (!m_chatModel->deserialize(stream, version))
return false;

View File

@ -12,6 +12,8 @@
#include <QObject>
#include <QQmlEngine>
#include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QStringView>
#include <QtGlobal>
class QDataStream;
@ -27,7 +29,6 @@ class Chat : public QObject
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@ -77,13 +78,12 @@ public:
bool isNewChat() const { return m_name == tr("New Chat") && !m_chatModel->count(); }
Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt();
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls = {});
Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void regenerateResponse(int index);
Q_INVOKABLE QVariant popPrompt(int index);
Q_INVOKABLE void stopGenerating();
QList<ResultInfo> databaseResults() const { return m_databaseResults; }
@ -92,7 +92,6 @@ public:
ResponseState responseState() const;
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &modelInfo);
bool restoringFromText() const;
Q_INVOKABLE void unloadModel();
Q_INVOKABLE void reloadModel();
@ -113,7 +112,6 @@ public:
Q_INVOKABLE bool hasCollection(const QString &collection) const;
Q_INVOKABLE void addCollection(const QString &collection);
Q_INVOKABLE void removeCollection(const QString &collection);
void resetResponseState();
QString modelLoadingError() const { return m_modelLoadingError; }
@ -131,7 +129,7 @@ public:
void setNeedsSave(bool n) { m_needsSave = n; }
public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
void resetResponseState();
Q_SIGNALS:
void idChanged(const QString &id);
@ -143,14 +141,12 @@ Q_SIGNALS:
void modelLoadingWarning(const QString &warning);
void responseInProgressChanged();
void responseStateChanged();
void promptRequested(const QList<QString> &collectionList, const QString &prompt);
void regenerateResponseRequested();
void promptRequested(const QStringList &enabledCollections);
void regenerateResponseRequested(int index);
void resetResponseRequested();
void resetContextRequested();
void processSystemPromptRequested();
void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged();
void restoringFromTextChanged();
void loadDefaultModelRequested();
void generateNameRequested();
void modelLoadingErrorChanged();
@ -166,22 +162,20 @@ Q_SIGNALS:
private Q_SLOTS:
void handleResponseChanged(const QString &response);
void handleResponseFailed(const QString &error);
void handleModelLoadingPercentageChanged(float);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
void handleRestoringFromText();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleModelChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value);
private:
void newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments);
private:
QString m_id;
QString m_name;

View File

@ -1,10 +1,10 @@
#include "chatapi.h"
#include <gpt4all-backend/llmodel.h>
#include "utils.h"
#include <QCoreApplication>
#include <QGuiApplication>
#include <QDebug>
#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
@ -13,12 +13,17 @@
#include <QNetworkRequest>
#include <QThread>
#include <QUrl>
#include <QUtf8StringView>
#include <QVariant>
#include <QXmlStreamReader>
#include <Qt>
#include <QtGlobal>
#include <QtLogging>
#include <expected>
#include <functional>
#include <iostream>
#include <utility>
using namespace Qt::Literals::StringLiterals;
@ -67,71 +72,119 @@ bool ChatAPI::isModelLoaded() const
return true;
}
void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::optional<std::string_view> fakeReply) {
static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
{
QJsonArray messages;
Q_UNUSED(promptCallback);
Q_UNUSED(allowContextShift);
Q_UNUSED(special);
auto xmlError = [&xml] {
return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
};
if (!isModelLoaded()) {
std::cerr << "ChatAPI ERROR: prompt won't work with an unloaded model!\n";
return;
if (xml.hasError())
return xmlError();
if (xml.atEnd())
return messages;
// skip header
bool foundElement = false;
do {
switch (xml.readNext()) {
using enum QXmlStreamReader::TokenType;
case Invalid:
return xmlError();
case EndDocument:
return messages;
default:
foundElement = true;
case StartDocument:
case Comment:
case DTD:
case ProcessingInstruction:
;
}
} while (!foundElement);
// document body loop
bool foundRoot = false;
for (;;) {
switch (xml.tokenType()) {
using enum QXmlStreamReader::TokenType;
case StartElement:
{
auto name = xml.name();
if (!foundRoot) {
if (name != "chat"_L1)
return std::unexpected(u"unexpected tag: %1"_s.arg(name));
foundRoot = true;
} else {
if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
return std::unexpected(u"unknown role: %1"_s.arg(name));
auto content = xml.readElementText();
if (xml.tokenType() != EndElement)
return xmlError();
messages << makeJsonObject({
{ "role"_L1, name.toString().trimmed() },
{ "content"_L1, content },
});
}
break;
}
case Characters:
if (!xml.isWhitespace())
return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
case Comment:
case ProcessingInstruction:
case EndElement:
break;
case EndDocument:
return messages;
case Invalid:
return xmlError();
default:
return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
}
xml.readNext();
}
}
if (!promptCtx.n_past) { m_queuedPrompts.clear(); }
Q_ASSERT(promptCtx.n_past <= m_context.size());
m_context.resize(promptCtx.n_past);
void ChatAPI::prompt(
std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
Q_UNUSED(promptCallback)
// FIXME(cebtenzzre): We're assuming people don't try to use %2 with ChatGPT. What would that even mean?
m_queuedPrompts << QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt));
if (!promptCtx.n_predict && !fakeReply) {
return; // response explicitly suppressed, queue prompt for later
}
QString formattedPrompt = m_queuedPrompts.join("");
m_queuedPrompts.clear();
if (fakeReply) {
promptCtx.n_past += 1;
m_context.append(formattedPrompt);
m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
return;
}
if (!isModelLoaded())
throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!promptCtx.n_predict)
return; // nothing requested
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
// the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
// the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
// using the REST API to count tokens in a prompt.
QJsonObject root;
root.insert("model", m_modelName);
root.insert("stream", true);
root.insert("temperature", promptCtx.temp);
root.insert("top_p", promptCtx.top_p);
auto root = makeJsonObject({
{ "model"_L1, m_modelName },
{ "stream"_L1, true },
{ "temperature"_L1, promptCtx.temp },
{ "top_p"_L1, promptCtx.top_p },
});
// conversation history
QJsonArray messages;
for (int i = 0; i < m_context.count(); ++i) {
QJsonObject message;
message.insert("role", i % 2 == 0 ? "user" : "assistant");
message.insert("content", m_context.at(i));
messages.append(message);
{
QUtf8StringView promptUtf8(prompt);
QXmlStreamReader xml(promptUtf8);
auto messages = parsePrompt(xml);
if (!messages) {
auto error = fmt::format("Failed to parse API model prompt: {}", messages.error());
qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n';
throw std::invalid_argument(error);
}
root.insert("messages"_L1, *messages);
}
QJsonObject promptObject;
promptObject.insert("role", "user");
promptObject.insert("content", formattedPrompt);
messages.append(promptObject);
root.insert("messages", messages);
QJsonDocument doc(root);
#if defined(DEBUG)
@ -148,12 +201,9 @@ void ChatAPI::prompt(const std::string &prompt,
connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
workerThread.start();
emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
emit request(m_apiKey, doc.toJson(QJsonDocument::Compact));
workerThread.wait();
promptCtx.n_past += 1;
m_context.append(formattedPrompt);
m_context.append(worker.currentResponse());
m_responseCallback = nullptr;
#if defined(DEBUG)
@ -171,12 +221,8 @@ bool ChatAPI::callResponse(int32_t token, const std::string& string)
return m_responseCallback(token, string);
}
void ChatAPIWorker::request(const QString &apiKey,
LLModel::PromptContext *promptCtx,
const QByteArray &array)
void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array)
{
m_ctx = promptCtx;
QUrl apiUrl(m_chat->url());
const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
QNetworkRequest request(apiUrl);
@ -283,7 +329,6 @@ void ChatAPIWorker::handleReadyRead()
const QJsonObject choice = choices.first().toObject();
const QJsonObject delta = choice.value("delta").toObject();
const QString content = delta.value("content").toString();
Q_ASSERT(m_ctx);
m_currentResponse += content;
if (!m_chat->callResponse(0, content.toStdString())) {
reply->abort();
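For illustration only (not part of the diff): with this change, ChatAPI::prompt() no longer takes a prompt template; it expects the whole conversation as a small XML document, which parsePrompt() above converts into the OpenAI-style "messages" array. A minimal sketch of that input, assuming only the <chat> root and the three role tags accepted by parsePrompt():
#include <QByteArray>
// Sketch only: a transcript in the shape ChatAPI::prompt() now consumes. Any tag other
// than <system>, <user>, or <assistant> under <chat> is rejected with "unknown role".
static const QByteArray kExampleTranscript = QByteArrayLiteral(
    "<chat>"
    "  <system>You are a helpful assistant.</system>"
    "  <user>What is the capital of France?</user>"
    "</chat>");
// parsePrompt() maps this to:
//   [ { "role": "system", "content": "You are a helpful assistant." },
//     { "role": "user",   "content": "What is the capital of France?" } ]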

View File

@ -7,17 +7,14 @@
#include <QNetworkReply>
#include <QObject>
#include <QString>
#include <QStringList>
#include <QList>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
class QNetworkAccessManager;
@ -28,16 +25,13 @@ class ChatAPIWorker : public QObject {
public:
ChatAPIWorker(ChatAPI *chatAPI)
: QObject(nullptr)
, m_ctx(nullptr)
, m_networkManager(nullptr)
, m_chat(chatAPI) {}
virtual ~ChatAPIWorker() {}
QString currentResponse() const { return m_currentResponse; }
void request(const QString &apiKey,
LLModel::PromptContext *promptCtx,
const QByteArray &array);
void request(const QString &apiKey, const QByteArray &array);
Q_SIGNALS:
void finished();
@ -49,7 +43,6 @@ private Q_SLOTS:
private:
ChatAPI *m_chat;
LLModel::PromptContext *m_ctx;
QNetworkAccessManager *m_networkManager;
QString m_currentResponse;
};
@ -74,14 +67,14 @@ public:
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &ctx,
bool special,
std::optional<std::string_view> fakeReply) override;
void prompt(std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &ctx) override;
[[noreturn]]
int32_t countPromptTokens(std::string_view prompt) const override
{ Q_UNUSED(prompt); throwNotImplemented(); }
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
@ -91,19 +84,17 @@ public:
void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
QString url() const { return m_requestURL; }
QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }
bool callResponse(int32_t token, const std::string &string);
[[noreturn]]
int32_t contextLength() const override
{ throwNotImplemented(); }
auto specialTokens() -> std::unordered_map<std::string, std::string> const override
{ return {}; }
Q_SIGNALS:
void request(const QString &apiKey,
LLModel::PromptContext *ctx,
const QByteArray &array);
void request(const QString &apiKey, const QByteArray &array);
protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use
@ -114,8 +105,8 @@ protected:
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
[[noreturn]]
std::vector<Token> tokenize(std::string_view str, bool special) override
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }
std::vector<Token> tokenize(std::string_view str) const override
{ Q_UNUSED(str); throwNotImplemented(); }
[[noreturn]]
bool isSpecialToken(Token id) const override
@ -126,7 +117,7 @@ protected:
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
void initSampler(PromptContext &ctx) override
void initSampler(const PromptContext &ctx) override
{ Q_UNUSED(ctx); throwNotImplemented(); }
[[noreturn]]
@ -134,33 +125,28 @@ protected:
{ throwNotImplemented(); }
[[noreturn]]
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }
bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override
{ Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); }
[[noreturn]]
void shiftContext(PromptContext &promptCtx) override
{ Q_UNUSED(promptCtx); throwNotImplemented(); }
void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override
{ Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); }
[[noreturn]]
int32_t inputLength() const override
{ throwNotImplemented(); }
[[noreturn]]
void setTokenizeInputPosition(int32_t pos) override
int32_t computeModelInputPosition(std::span<const Token> input) const override
{ Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
void appendInputToken(PromptContext &ctx, Token tok) override
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
void appendInputToken(Token tok) override
{ Q_UNUSED(tok); throwNotImplemented(); }
[[noreturn]]
const std::vector<Token> &endTokens() const override
@ -175,12 +161,10 @@ protected:
{ throwNotImplemented(); }
private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;
QString m_modelName;
QString m_apiKey;
QString m_requestURL;
QList<QString> m_context;
QStringList m_queuedPrompts;
ResponseCallback m_responseCallback;
QString m_modelName;
QString m_apiKey;
QString m_requestURL;
};
#endif // CHATAPI_H

View File

@ -17,9 +17,10 @@
#include <Qt>
#include <algorithm>
#include <memory>
#define CHAT_FORMAT_MAGIC 0xF5D553CC
#define CHAT_FORMAT_VERSION 10
static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC;
static constexpr qint32 CHAT_FORMAT_VERSION = 11;
class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
@ -118,8 +119,8 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
}
QDataStream out(&tempFile);
out << (quint32)CHAT_FORMAT_MAGIC;
out << (qint32)CHAT_FORMAT_VERSION;
out << CHAT_FORMAT_MAGIC;
out << CHAT_FORMAT_VERSION;
out.setVersion(QDataStream::Qt_6_2);
qDebug() << "serializing chat" << fileName;
@ -257,12 +258,15 @@ void ChatsRestoreThread::run()
qDebug() << "deserializing chat" << f.file;
Chat *chat = new Chat;
auto chat = std::make_unique<Chat>();
chat->moveToThread(qGuiApp->thread());
if (!chat->deserialize(in, version)) {
bool ok = chat->deserialize(in, version);
if (!ok) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
} else if (!in.atEnd()) {
qWarning().nospace() << "error loading chat from " << file.fileName() << ": extra data at end of file";
} else {
emit chatRestored(chat);
emit chatRestored(chat.release());
}
if (f.oldFile)
file.remove(); // No longer storing in this directory
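A minimal sketch, not taken from the diff, of how the header written by ChatSaver::saveChats() above could be validated on load; `path` is a placeholder, and the real restore logic lives in ChatsRestoreThread::run():
#include <QDataStream>
#include <QFile>
#include <QIODevice>
#include <QString>
// Sketch only: check the magic/version header before attempting Chat::deserialize().
static bool looksLikeChatFile(const QString &path)
{
    QFile file(path);
    if (!file.open(QIODevice::ReadOnly))
        return false;
    QDataStream in(&file);
    quint32 magic;
    qint32 version;
    in >> magic >> version;
    in.setVersion(QDataStream::Qt_6_2);   // matches the writer above
    return magic == CHAT_FORMAT_MAGIC && version <= CHAT_FORMAT_VERSION;
}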

File diff suppressed because it is too large

View File

@ -13,20 +13,24 @@
#include <QObject>
#include <QPointer>
#include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QStringView>
#include <QThread>
#include <QVariantMap>
#include <QVariantMap> // IWYU pragma: keep
#include <QtGlobal>
#include <atomic>
#include <cstdint>
#include <memory>
#include <optional>
#include <span>
#include <string>
#include <vector>
#include <variant>
using namespace Qt::Literals::StringLiterals;
class QDataStream;
struct ChatItem;
// NOTE: values serialized to disk, do not change or reuse
enum class LLModelTypeV0 { // chat versions 2-5
@ -142,7 +146,6 @@ class Chat;
class ChatLLM : public QObject
{
Q_OBJECT
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
@ -150,12 +153,14 @@ public:
ChatLLM(Chat *parent, bool isServer = false);
virtual ~ChatLLM();
void destroy();
static void destroyStore();
static std::optional<std::string> checkJinjaTemplateError(const std::string &source);
void destroy();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
void resetContext();
void regenerateResponse(int index);
// used to implement edit functionality
std::optional<QString> popPrompt(int index);
void stopGenerating() { m_stopGenerating = true; }
@ -165,13 +170,9 @@ public:
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
QString response(bool trim = true) const;
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &info);
bool restoringFromText() const { return m_restoringFromText; }
void acquireModel();
void resetModel();
@ -196,13 +197,11 @@ public:
return m_llModelInfo.fallbackReason.value_or(u""_s);
}
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
void prompt(const QStringList &enabledCollections);
bool loadDefaultModel();
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo);
@ -210,22 +209,19 @@ public Q_SLOTS:
void unloadModel();
void reloadModel();
void generateName();
void generateQuestions(qint64 elapsed);
void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged();
void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();
Q_SIGNALS:
void restoringFromTextChanged();
void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error);
void modelLoadingWarning(const QString &warning);
void responseChanged(const QString &response);
void responseFailed(const QString &error);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
@ -244,58 +240,50 @@ Q_SIGNALS:
void modelInfoChanged(const ModelInfo &modelInfo);
protected:
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens, std::optional<QString> fakeReply = {});
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleQuestionPrompt(int32_t token);
bool handleQuestionResponse(int32_t token, const std::string &response);
void saveState();
void restoreState();
struct PromptResult {
QByteArray response; // raw UTF-8
int promptTokens; // note: counts *entire* history, even if cached
int responseTokens;
};
protected:
LLModel::PromptContext m_ctx;
quint32 m_promptTokens;
quint32 m_promptResponseTokens;
struct ChatPromptResult : PromptResult {
QList<ResultInfo> databaseResults;
};
ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx);
// passing a string_view directly skips templating and uses the raw string
PromptResult promptInternal(const std::variant<std::span<const ChatItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
bool usedLocalDocs);
private:
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
std::vector<ChatItem> forkConversation(const QString &prompt) const;
// Applies the Jinja template. Query mode returns only the last message without special tokens.
// Returns a (# of messages, rendered prompt) pair.
std::string applyJinjaTemplate(std::span<const ChatItem> items) const;
void generateQuestions(qint64 elapsed);
protected:
QPointer<ChatModel> m_chatModel;
private:
const Chat *m_chat;
std::string m_response;
std::string m_trimmedResponse;
std::string m_nameResponse;
QString m_questionResponse;
LLModelInfo m_llModelInfo;
LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
ModelInfo m_modelInfo;
TokenTimer *m_timer;
QByteArray m_state;
std::vector<LLModel::Token> m_stateInputTokens;
int32_t m_stateContextLength = -1;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_restoringFromText; // status indication
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
bool m_isServer;
bool m_forceMetal;
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
// m_pristineLoadedState is set if saveState is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QPointer<ChatModel> m_chatModel;
};
#endif // CHATLLM_H
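As a usage sketch for the static checkJinjaTemplateError() declared above (the template string here is only an example, not a recommended template):
#include "chatllm.h"
#include <QDebug>
#include <QString>
#include <string>
static void validateTemplateExample()
{
    std::string tmpl = "{% for message in messages %}{{ message['content'] }}\n{% endfor %}";
    if (auto err = ChatLLM::checkJinjaTemplateError(tmpl))
        qWarning() << "chat template rejected:" << QString::fromStdString(*err);
    else
        qDebug() << "chat template parsed OK";
}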

View File

@ -2,8 +2,11 @@
#define CHATMODEL_H
#include "database.h"
#include "utils.h"
#include "xlsxtomd.h"
#include <fmt/format.h>
#include <QAbstractListModel>
#include <QBuffer>
#include <QByteArray>
@ -18,6 +21,15 @@
#include <Qt>
#include <QtGlobal>
#include <iterator>
#include <ranges>
#include <span>
#include <utility>
using namespace Qt::Literals::StringLiterals;
namespace ranges = std::ranges;
struct PromptAttachment {
Q_GADGET
Q_PROPERTY(QUrl url MEMBER url)
@ -60,66 +72,145 @@ Q_DECLARE_METATYPE(PromptAttachment)
struct ChatItem
{
Q_GADGET
Q_PROPERTY(QString name MEMBER name)
Q_PROPERTY(QString name MEMBER name )
Q_PROPERTY(QString value MEMBER value)
Q_PROPERTY(QString newResponse MEMBER newResponse)
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
Q_PROPERTY(bool stopped MEMBER stopped)
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
Q_PROPERTY(QList<ResultInfo> sources MEMBER sources)
Q_PROPERTY(QList<ResultInfo> consolidatedSources MEMBER consolidatedSources)
// prompts
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments)
Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments)
Q_PROPERTY(QString bakedPrompt READ bakedPrompt )
// responses
Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse)
Q_PROPERTY(bool isError MEMBER isError )
// responses (DataLake)
Q_PROPERTY(QString newResponse MEMBER newResponse )
Q_PROPERTY(bool stopped MEMBER stopped )
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState )
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
public:
QString promptPlusAttachments() const
{
QStringList attachedContexts;
for (auto attached : promptAttachments)
attachedContexts << attached.processedContent();
enum class Type { System, Prompt, Response };
QString promptPlus = value;
if (!attachedContexts.isEmpty())
promptPlus = attachedContexts.join("\n\n") + "\n\n" + value;
return promptPlus;
// tags for constructing ChatItems
struct prompt_tag_t { explicit prompt_tag_t() = default; };
static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t();
struct response_tag_t { explicit response_tag_t() = default; };
static inline constexpr response_tag_t response_tag = response_tag_t();
struct system_tag_t { explicit system_tag_t() = default; };
static inline constexpr system_tag_t system_tag = system_tag_t();
// FIXME(jared): This should not be necessary. QML should see null or undefined if it
// tries to access something invalid.
ChatItem() = default;
// NOTE: system messages are currently never stored in the model or serialized
ChatItem(system_tag_t, const QString &value)
: name(u"System: "_s), value(value) {}
ChatItem(prompt_tag_t, const QString &value, const QList<PromptAttachment> &attachments = {})
: name(u"Prompt: "_s), value(value), promptAttachments(attachments) {}
ChatItem(response_tag_t, bool isCurrentResponse = true)
: name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {}
Type type() const
{
if (name == u"System: "_s)
return Type::System;
if (name == u"Prompt: "_s)
return Type::Prompt;
if (name == u"Response: "_s)
return Type::Response;
throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name));
}
// used with version 0 Jinja templates
QString bakedPrompt() const
{
if (type() != Type::Prompt)
throw std::logic_error("bakedPrompt() called on non-prompt item");
QStringList parts;
if (!sources.isEmpty()) {
parts << u"### Context:\n"_s;
for (auto &source : std::as_const(sources))
parts << u"Collection: "_s << source.collection
<< u"\nPath: "_s << source.path
<< u"\nExcerpt: "_s << source.text << u"\n\n"_s;
}
for (auto &attached : std::as_const(promptAttachments))
parts << attached.processedContent() << u"\n\n"_s;
parts << value;
return parts.join(QString());
}
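For illustration, a minimal sketch (not part of this commit; the ResultInfo values are invented and its fields are assumed to be directly assignable) of the string bakedPrompt() assembles for a prompt carrying one LocalDocs source and no attachments:
#include "chatmodel.h"
#include <QDebug>
// Hypothetical helper, illustrative only.
static void printBakedPromptExample()
{
    ResultInfo source;
    source.collection = u"Manuals"_s;
    source.path       = u"/docs/guide.pdf"_s;
    source.text       = u"Press the red button to stop the machine."_s;
    ChatItem prompt(ChatItem::prompt_tag, u"How do I stop the machine?"_s);
    prompt.sources << source;
    // Roughly: "### Context:\nCollection: Manuals\nPath: /docs/guide.pdf\n"
    //          "Excerpt: Press the red button to stop the machine.\n\n"
    //          "How do I stop the machine?"
    qDebug().noquote() << prompt.bakedPrompt();
}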
// TODO: Maybe we should include the model name here as well as timestamp?
QString name;
QString value;
QString newResponse;
QList<ResultInfo> sources;
QList<ResultInfo> consolidatedSources;
// prompts
QList<ResultInfo> sources;
QList<ResultInfo> consolidatedSources;
QList<PromptAttachment> promptAttachments;
bool currentResponse = false;
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
// responses
bool isCurrentResponse = false;
bool isError = false;
// responses (DataLake)
QString newResponse;
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
};
Q_DECLARE_METATYPE(ChatItem)
using ChatModelIterator = QList<ChatItem>::const_iterator;
class ChatModelAccessor : public ranges::subrange<QList<ChatItem>::const_iterator> {
private:
using Super = ranges::subrange<QList<ChatItem>::const_iterator>;
public:
template <typename... T>
ChatModelAccessor(QMutex &mutex, T &&...args)
: Super(std::forward<T>(args)...), m_lock(&mutex) {}
private:
QMutexLocker<QMutex> m_lock;
};
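The accessor above pairs a std::ranges::subrange with a QMutexLocker, so the model's mutex stays held for as long as the range is alive. A hedged usage sketch (assumed caller, not from this commit), iterating the range returned by ChatModel::chatItems() further below:
#include "chatmodel.h"
// Illustrative only: counts responses while the accessor keeps m_mutex locked.
static int countResponses(const ChatModel &model)
{
    int n = 0;
    for (const ChatItem &item : model.chatItems()) { // lock acquired when the accessor is built
        if (item.type() == ChatItem::Type::Response)
            ++n;
    }
    return n; // lock released when the temporary accessor is destroyed
}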
class ChatModel : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(bool hasError READ hasError NOTIFY hasErrorChanged)
public:
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
explicit ChatModel(QObject *parent = nullptr)
: QAbstractListModel(parent) {}
// FIXME(jared): can't this start at Qt::UserRole (no +1)?
enum Roles {
NameRole = Qt::UserRole + 1,
ValueRole,
// prompts and responses
PeerRole,
// prompts
PromptAttachmentsRole,
// responses
// NOTE: sources are stored on the *prompts*, but in the model, they are only on the *responses*!
SourcesRole,
ConsolidatedSourcesRole,
IsCurrentResponseRole,
IsErrorRole,
// responses (DataLake)
NewResponseRole,
CurrentResponseRole,
StoppedRole,
ThumbsUpStateRole,
ThumbsDownStateRole,
SourcesRole,
ConsolidatedSourcesRole,
PromptAttachmentsRole
};
int rowCount(const QModelIndex &parent = QModelIndex()) const override
@ -129,34 +220,96 @@ public:
return m_chatItems.size();
}
/* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs
* sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */
auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const
-> std::optional<QList<ChatItem>::const_iterator>
{
switch (item->type()) {
using enum ChatItem::Type;
case Prompt:
{
auto peer = std::next(item);
if (peer < m_chatItems.cend() && peer->type() == Response)
return peer;
break;
}
case Response:
{
if (item > m_chatItems.cbegin()) {
if (auto peer = std::prev(item); peer->type() == Prompt)
return peer;
}
break;
}
default:
throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
}
return std::nullopt;
}
auto getPeerUnlocked(int index) const -> std::optional<int>
{
return getPeerUnlocked(m_chatItems.cbegin() + index)
.transform([&](auto &&i) { return i - m_chatItems.cbegin(); } );
}
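As noted in the Roles enum above, sources are stored on prompt items but surfaced on response rows; the peer link is what bridges the two. A hedged sketch (assumed caller, not from this commit) of reading a response row's sources through the model API:
#include "chatmodel.h"
// Illustrative only: data() resolves the response -> prompt peer internally.
static QList<ResultInfo> sourcesForResponseRow(const ChatModel &model, int responseRow)
{
    QModelIndex idx = model.index(responseRow, 0);
    return model.data(idx, ChatModel::SourcesRole).value<QList<ResultInfo>>();
}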
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{
QMutexLocker locker(&m_mutex);
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant();
const ChatItem &item = m_chatItems.at(index.row());
auto item = m_chatItems.cbegin() + index.row();
switch (role) {
case NameRole:
return item.name;
return item->name;
case ValueRole:
return item.value;
case NewResponseRole:
return item.newResponse;
case CurrentResponseRole:
return item.currentResponse;
case StoppedRole:
return item.stopped;
case ThumbsUpStateRole:
return item.thumbsUpState;
case ThumbsDownStateRole:
return item.thumbsDownState;
case SourcesRole:
return QVariant::fromValue(item.sources);
case ConsolidatedSourcesRole:
return QVariant::fromValue(item.consolidatedSources);
return item->value;
case PeerRole:
switch (item->type()) {
using enum ChatItem::Type;
case Prompt:
case Response:
{
auto peer = getPeerUnlocked(item);
return peer ? QVariant::fromValue(**peer) : QVariant::fromValue(nullptr);
}
default:
return QVariant();
}
case PromptAttachmentsRole:
return QVariant::fromValue(item.promptAttachments);
return QVariant::fromValue(item->promptAttachments);
case SourcesRole:
{
QList<ResultInfo> data;
if (item->type() == ChatItem::Type::Response) {
if (auto prompt = getPeerUnlocked(item))
data = (*prompt)->consolidatedSources;
}
return QVariant::fromValue(data);
}
case ConsolidatedSourcesRole:
{
QList<ResultInfo> data;
if (item->type() == ChatItem::Type::Response) {
if (auto prompt = getPeerUnlocked(item))
data = (*prompt)->sources;
}
return QVariant::fromValue(data);
}
case IsCurrentResponseRole:
return item->isCurrentResponse;
case NewResponseRole:
return item->newResponse;
case StoppedRole:
return item->stopped;
case ThumbsUpStateRole:
return item->thumbsUpState;
case ThumbsDownStateRole:
return item->thumbsDownState;
case IsErrorRole:
return item->type() == ChatItem::Type::Response && item->isError;
}
return QVariant();
@ -164,54 +317,126 @@ public:
QHash<int, QByteArray> roleNames() const override
{
QHash<int, QByteArray> roles;
roles[NameRole] = "name";
roles[ValueRole] = "value";
roles[NewResponseRole] = "newResponse";
roles[CurrentResponseRole] = "currentResponse";
roles[StoppedRole] = "stopped";
roles[ThumbsUpStateRole] = "thumbsUpState";
roles[ThumbsDownStateRole] = "thumbsDownState";
roles[SourcesRole] = "sources";
roles[ConsolidatedSourcesRole] = "consolidatedSources";
roles[PromptAttachmentsRole] = "promptAttachments";
return roles;
return {
{ NameRole, "name" },
{ ValueRole, "value" },
{ PeerRole, "peer" },
{ PromptAttachmentsRole, "promptAttachments" },
{ SourcesRole, "sources" },
{ ConsolidatedSourcesRole, "consolidatedSources" },
{ IsCurrentResponseRole, "isCurrentResponse" },
{ IsErrorRole, "isError" },
{ NewResponseRole, "newResponse" },
{ StoppedRole, "stopped" },
{ ThumbsUpStateRole, "thumbsUpState" },
{ ThumbsDownStateRole, "thumbsDownState" },
};
}
void appendPrompt(const QString &name, const QString &value, const QList<PromptAttachment> &attachments)
void appendPrompt(const QString &value, const QList<PromptAttachment> &attachments = {})
{
ChatItem item;
item.name = name;
item.value = value;
item.promptAttachments << attachments;
qsizetype count;
{
QMutexLocker locker(&m_mutex);
if (hasErrorUnlocked())
throw std::logic_error("cannot append to a failed chat");
count = m_chatItems.count();
}
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments);
}
endInsertRows();
emit countChanged();
}
void appendResponse(const QString &name)
void appendResponse(int promptIndex)
{
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
ChatItem item;
item.name = name;
item.currentResponse = true;
qsizetype count;
{
QMutexLocker locker(&m_mutex);
if (hasErrorUnlocked())
throw std::logic_error("cannot append to a failed chat");
count = m_chatItems.count();
}
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
if (promptIndex >= 0) {
if (promptIndex >= m_chatItems.size())
throw std::out_of_range(fmt::format("index {} is out of range", promptIndex));
auto &promptItem = m_chatItems[promptIndex];
if (promptItem.type() != ChatItem::Type::Prompt)
throw std::invalid_argument(fmt::format("item at index {} is not a prompt", promptIndex));
}
m_chatItems.emplace_back(ChatItem::response_tag, promptIndex);
}
endInsertRows();
emit countChanged();
if (promptIndex >= 0)
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
}
// Used by Server to append a new conversation to the chat log.
void appendResponseWithHistory(std::span<const ChatItem> history)
{
if (history.empty())
throw std::invalid_argument("at least one message is required");
m_mutex.lock();
qsizetype startIndex = m_chatItems.count();
m_mutex.unlock();
qsizetype nNewItems = history.size() + 1;
qsizetype endIndex = startIndex + nNewItems;
beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
bool hadError;
int promptIndex;
{
QMutexLocker locker(&m_mutex);
hadError = hasErrorUnlocked();
m_chatItems.reserve(m_chatItems.count() + nNewItems);
for (auto &item : history)
m_chatItems << item;
m_chatItems.emplace_back(ChatItem::response_tag);
}
endInsertRows();
emit countChanged();
// Server can add messages when there is an error because each call is a new conversation
if (hadError)
emit hasErrorChanged(false);
if (promptIndex >= 0)
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
}
void truncate(qsizetype size)
{
qsizetype oldSize;
{
QMutexLocker locker(&m_mutex);
if (size >= (oldSize = m_chatItems.size()))
return;
if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response)
throw std::invalid_argument(
fmt::format("chat model truncated to {} items would not end in a response", size)
);
}
bool oldHasError;
beginRemoveRows(QModelIndex(), size, oldSize - 1 /*inclusive*/);
{
QMutexLocker locker(&m_mutex);
oldHasError = hasErrorUnlocked();
Q_ASSERT(size < m_chatItems.size());
m_chatItems.resize(size);
}
endRemoveRows();
emit countChanged();
if (oldHasError)
emit hasErrorChanged(false);
}
Q_INVOKABLE void clear()
@ -221,13 +446,17 @@ public:
if (m_chatItems.isEmpty()) return;
}
bool oldHasError;
beginResetModel();
{
QMutexLocker locker(&m_mutex);
oldHasError = hasErrorUnlocked();
m_chatItems.clear();
}
endResetModel();
emit countChanged();
if (oldHasError)
emit hasErrorChanged(false);
}
Q_INVOKABLE ChatItem get(int index)
@ -245,13 +474,13 @@ public:
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) {
item.currentResponse = b;
if (item.isCurrentResponse != b) {
item.isCurrentResponse = b;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole});
}
Q_INVOKABLE void updateStopped(int index, bool b)
@ -304,16 +533,23 @@ public:
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
{
int responseIndex = -1;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
item.sources = sources;
item.consolidatedSources = consolidateSources(sources);
auto promptItem = m_chatItems.begin() + index;
if (promptItem->type() != ChatItem::Type::Prompt)
throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index));
if (auto peer = getPeerUnlocked(promptItem))
responseIndex = *peer - m_chatItems.cbegin();
promptItem->sources = sources;
promptItem->consolidatedSources = consolidateSources(sources);
}
if (responseIndex >= 0) {
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole});
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {ConsolidatedSourcesRole});
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
}
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
@ -364,18 +600,56 @@ public:
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}
int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
Q_INVOKABLE void setError(bool value = true)
{
qsizetype index;
{
QMutexLocker locker(&m_mutex);
ChatModelIterator begin() const { return m_chatItems.begin(); }
ChatModelIterator end() const { return m_chatItems.end(); }
void lock() { m_mutex.lock(); }
void unlock() { m_mutex.unlock(); }
if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response)
throw std::logic_error("can only set error on a chat that ends with a response");
index = m_chatItems.count() - 1;
auto &last = m_chatItems.back();
if (last.isError == value)
return; // already set
last.isError = value;
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole});
emit hasErrorChanged(value);
}
qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; }
bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); }
bool serialize(QDataStream &stream, int version) const
{
QMutexLocker locker(&m_mutex);
stream << int(m_chatItems.size());
for (const auto &c : m_chatItems) {
for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) {
auto c = *itemIt; // NB: copies
if (version < 11) {
// move sources from their prompt to the next response
switch (c.type()) {
using enum ChatItem::Type;
case Prompt:
c.sources.clear();
c.consolidatedSources.clear();
break;
case Response:
// note: we drop sources for responseless prompts
if (auto peer = getPeerUnlocked(itemIt)) {
c.sources = (*peer)->sources;
c.consolidatedSources = (*peer)->consolidatedSources;
}
default:
;
}
}
// FIXME: This 'id' should be eliminated the next time we bump serialization version.
// (Jared) This was apparently never used.
int id = 0;
@ -383,10 +657,12 @@ public:
stream << c.name;
stream << c.value;
stream << c.newResponse;
stream << c.currentResponse;
stream << c.isCurrentResponse;
stream << c.stopped;
stream << c.thumbsUpState;
stream << c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response)
stream << c.isError;
if (version >= 8) {
stream << c.sources.size();
for (const ResultInfo &info : c.sources) {
@ -452,14 +728,24 @@ public:
bool deserialize(QDataStream &stream, int version)
{
clear(); // reset to known state
int size;
stream >> size;
int lastPromptIndex = -1;
QList<ChatItem> chatItems;
for (int i = 0; i < size; ++i) {
ChatItem c;
// FIXME: see comment in serialization about id
int id;
stream >> id;
stream >> c.name;
try {
c.type(); // check name
} catch (const std::exception &e) {
qWarning() << "ChatModel ERROR:" << e.what();
return false;
}
stream >> c.value;
if (version < 10) {
// This is deprecated and no longer used
@ -467,10 +753,12 @@ public:
stream >> prompt;
}
stream >> c.newResponse;
stream >> c.currentResponse;
stream >> c.isCurrentResponse;
stream >> c.stopped;
stream >> c.thumbsUpState;
stream >> c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response)
stream >> c.isError;
if (version >= 8) {
qsizetype count;
stream >> count;
@ -587,23 +875,53 @@ public:
}
c.promptAttachments = attachments;
}
m_mutex.lock();
const int count = m_chatItems.size();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(c);
if (version < 11 && c.type() == ChatItem::Type::Response) {
// move sources from the response to their last prompt
if (lastPromptIndex >= 0) {
auto &prompt = chatItems[lastPromptIndex];
prompt.sources = std::move(c.sources );
prompt.consolidatedSources = std::move(c.consolidatedSources);
lastPromptIndex = -1;
} else {
// drop sources for promptless responses
c.sources.clear();
c.consolidatedSources.clear();
}
}
endInsertRows();
chatItems << c;
if (c.type() == ChatItem::Type::Prompt)
lastPromptIndex = chatItems.size() - 1;
}
bool hasError;
beginInsertRows(QModelIndex(), 0, chatItems.size() - 1 /*inclusive*/);
{
QMutexLocker locker(&m_mutex);
m_chatItems = chatItems;
hasError = hasErrorUnlocked();
}
endInsertRows();
emit countChanged();
if (hasError)
emit hasErrorChanged(true);
return stream.status() == QDataStream::Ok;
}
Q_SIGNALS:
void countChanged();
void valueChanged(int index, const QString &value);
void hasErrorChanged(bool value);
private:
bool hasErrorUnlocked() const
{
if (m_chatItems.isEmpty())
return false;
auto &last = m_chatItems.back();
return last.type() == ChatItem::Type::Response && last.isError;
}
private:
mutable QMutex m_mutex;

View File

@ -0,0 +1,111 @@
#include "jinja_helpers.h"
#include "utils.h"
#include <fmt/format.h>
#include <QString>
#include <QUrl>
#include <memory>
#include <vector>
using namespace std::literals::string_view_literals;
JinjaResultInfo::~JinjaResultInfo() = default;
const JinjaFieldMap<ResultInfo> JinjaResultInfo::s_fields = {
{ "collection", [](auto &s) { return s.collection.toStdString(); } },
{ "path", [](auto &s) { return s.path .toStdString(); } },
{ "file", [](auto &s) { return s.file .toStdString(); } },
{ "title", [](auto &s) { return s.title .toStdString(); } },
{ "author", [](auto &s) { return s.author .toStdString(); } },
{ "date", [](auto &s) { return s.date .toStdString(); } },
{ "text", [](auto &s) { return s.text .toStdString(); } },
{ "page", [](auto &s) { return s.page; } },
{ "file_uri", [](auto &s) { return s.fileUri() .toStdString(); } },
};
JinjaPromptAttachment::~JinjaPromptAttachment() = default;
const JinjaFieldMap<PromptAttachment> JinjaPromptAttachment::s_fields = {
{ "url", [](auto &s) { return s.url.toString() .toStdString(); } },
{ "file", [](auto &s) { return s.file() .toStdString(); } },
{ "processed_content", [](auto &s) { return s.processedContent().toStdString(); } },
};
std::vector<std::string> JinjaMessage::GetKeys() const
{
std::vector<std::string> result;
auto &keys = this->keys();
result.reserve(keys.size());
result.assign(keys.begin(), keys.end());
return result;
}
auto JinjaMessage::keys() const -> const std::unordered_set<std::string_view> &
{
static const std::unordered_set<std::string_view> baseKeys
{ "role", "content" };
static const std::unordered_set<std::string_view> userKeys
{ "role", "content", "sources", "prompt_attachments" };
switch (m_item->type()) {
using enum ChatItem::Type;
case System:
case Response:
return baseKeys;
case Prompt:
return userKeys;
}
Q_UNREACHABLE();
}
bool operator==(const JinjaMessage &a, const JinjaMessage &b)
{
if (a.m_item == b.m_item)
return true;
const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item);
auto type = ia.type();
if (type != ib.type() || ia.value != ib.value)
return false;
switch (type) {
using enum ChatItem::Type;
case System:
case Response:
return true;
case Prompt:
return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments;
}
Q_UNREACHABLE();
}
const JinjaFieldMap<JinjaMessage> JinjaMessage::s_fields = {
{ "role", [](auto &m) {
switch (m.item().type()) {
using enum ChatItem::Type;
case System: return "system"sv;
case Prompt: return "user"sv;
case Response: return "assistant"sv;
}
Q_UNREACHABLE();
} },
{ "content", [](auto &m) {
if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt)
return m.item().bakedPrompt().toStdString();
return m.item().value.toStdString();
} },
{ "sources", [](auto &m) {
auto sources = m.item().sources | views::transform([](auto &r) {
return jinja2::GenericMap([map = std::make_shared<JinjaResultInfo>(r)] { return map.get(); });
});
return jinja2::ValuesList(sources.begin(), sources.end());
} },
{ "prompt_attachments", [](auto &m) {
auto attachments = m.item().promptAttachments | views::transform([](auto &pa) {
return jinja2::GenericMap([map = std::make_shared<JinjaPromptAttachment>(pa)] { return map.get(); });
});
return jinja2::ValuesList(attachments.begin(), attachments.end());
} },
};
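These field maps define exactly what a chat template can see for each message. A hedged sketch (illustrative only, not from this commit) of querying a JinjaMessage directly:
#include "chatmodel.h"
#include "jinja_helpers.h"
#include <cassert>
// Illustrative only: a user message exposes four keys, system/assistant messages two.
static void jinjaMessageExample()
{
    ChatItem item(ChatItem::prompt_tag, u"Hello!"_s);
    JinjaMessage msg(/*version*/ 1, item);
    assert(msg.HasValue("role") && msg.HasValue("content"));
    assert(msg.HasValue("sources") && msg.HasValue("prompt_attachments"));
    assert(msg.GetKeys().size() == 4u);
    // "role" resolves to "user" via s_fields above; unknown keys yield jinja2::EmptyValue().
    jinja2::Value role = msg.GetValueByName("role");
    (void)role;
}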

View File

@ -0,0 +1,116 @@
#pragma once
#include "chatmodel.h"
#include "database.h"
#include <jinja2cpp/value.h>
#include <functional>
#include <ranges>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <QtGlobal>
namespace views = std::views;
template <typename T>
using JinjaFieldMap = std::unordered_map<std::string_view, std::function<jinja2::Value (const T &)>>;
template <typename Derived>
class JinjaComparable : public jinja2::IMapItemAccessor {
public:
JinjaComparable() = default;
bool IsEqual(const jinja2::IComparable &other) const override;
private:
Q_DISABLE_COPY_MOVE(JinjaComparable)
};
template <typename Derived>
class JinjaHelper : public JinjaComparable<Derived> {
public:
size_t GetSize() const override
{ return Derived::s_fields.size(); }
bool HasValue(const std::string &name) const override
{ return Derived::s_fields.contains(name); }
jinja2::Value GetValueByName(const std::string &name) const override;
std::vector<std::string> GetKeys() const override
{ auto keys = views::elements<0>(Derived::s_fields); return { keys.begin(), keys.end() }; }
};
class JinjaResultInfo : public JinjaHelper<JinjaResultInfo> {
public:
explicit JinjaResultInfo(const ResultInfo &source) noexcept
: m_source(&source) {}
~JinjaResultInfo() override;
const ResultInfo &value() const { return *m_source; }
friend bool operator==(const JinjaResultInfo &a, const JinjaResultInfo &b)
{ return a.m_source == b.m_source || *a.m_source == *b.m_source; }
private:
static const JinjaFieldMap<ResultInfo> s_fields;
const ResultInfo *m_source;
friend class JinjaHelper<JinjaResultInfo>;
};
class JinjaPromptAttachment : public JinjaHelper<JinjaPromptAttachment> {
public:
explicit JinjaPromptAttachment(const PromptAttachment &attachment) noexcept
: m_attachment(&attachment) {}
~JinjaPromptAttachment() override;
const PromptAttachment &value() const { return *m_attachment; }
friend bool operator==(const JinjaPromptAttachment &a, const JinjaPromptAttachment &b)
{ return a.m_attachment == b.m_attachment || *a.m_attachment == *b.m_attachment; }
private:
static const JinjaFieldMap<PromptAttachment> s_fields;
const PromptAttachment *m_attachment;
friend class JinjaHelper<JinjaPromptAttachment>;
};
class JinjaMessage : public JinjaHelper<JinjaMessage> {
public:
explicit JinjaMessage(uint version, const ChatItem &item) noexcept
: m_version(version), m_item(&item) {}
const JinjaMessage &value () const { return *this; }
uint version() const { return m_version; }
const ChatItem &item () const { return *m_item; }
size_t GetSize() const override { return keys().size(); }
bool HasValue(const std::string &name) const override { return keys().contains(name); }
jinja2::Value GetValueByName(const std::string &name) const override
{ return HasValue(name) ? JinjaHelper::GetValueByName(name) : jinja2::EmptyValue(); }
std::vector<std::string> GetKeys() const override;
private:
auto keys() const -> const std::unordered_set<std::string_view> &;
private:
static const JinjaFieldMap<JinjaMessage> s_fields;
uint m_version;
const ChatItem *m_item;
friend class JinjaHelper<JinjaMessage>;
friend bool operator==(const JinjaMessage &a, const JinjaMessage &b);
};
#include "jinja_helpers.inl"

View File

@ -0,0 +1,17 @@
template <typename D>
bool JinjaComparable<D>::IsEqual(const jinja2::IComparable &other) const
{
if (auto *omsg = dynamic_cast<const D *>(&other))
return *static_cast<const D *>(this) == *omsg;
return false;
}
template <typename D>
jinja2::Value JinjaHelper<D>::GetValueByName(const std::string &name) const
{
if (auto it = D::s_fields.find(name); it != D::s_fields.end()) {
auto [_, func] = *it;
return func(static_cast<const D *>(this)->value());
}
return jinja2::EmptyValue();
}

View File

@ -12,12 +12,16 @@
#include <singleapplication.h>
#include <QCoreApplication>
#include <QFont>
#include <QFontDatabase>
#include <QObject>
#include <QQmlApplicationEngine>
#include <QQmlContext>
#include <QQuickWindow>
#include <QSettings>
#include <QString>
#include <QUrl>
#include <QVariant>
#include <Qt>
#ifdef Q_OS_LINUX
@ -91,18 +95,22 @@ int main(int argc, char *argv[])
// Set the locale and language translation before the qml engine has even been started. This will
// use the default system locale unless the user has explicitly set it to use a different one.
MySettings::globalInstance()->setLanguageAndLocale();
auto *mySettings = MySettings::globalInstance();
mySettings->setLanguageAndLocale();
QQmlApplicationEngine engine;
// Connect the MySettings::languageAndLocaleChanged signal to a lambda slot that updates the
// engine.uiLanguage property
QObject::connect(MySettings::globalInstance(), &MySettings::languageAndLocaleChanged, [&engine]() {
QObject::connect(mySettings, &MySettings::languageAndLocaleChanged, [&engine]() {
engine.setUiLanguage(MySettings::globalInstance()->languageAndLocale());
});
qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance());
qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance());
auto *modelList = ModelList::globalInstance();
QObject::connect(modelList, &ModelList::dataChanged, mySettings, &MySettings::onModelInfoChanged);
qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", mySettings);
qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", modelList);
qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance());
qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance());
qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance());
@ -110,6 +118,11 @@ int main(int argc, char *argv[])
qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance());
qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums");
{
auto fixedFont = QFontDatabase::systemFont(QFontDatabase::FixedFont);
engine.rootContext()->setContextProperty("fixedFont", fixedFont);
}
const QUrl url(u"qrc:/gpt4all/main.qml"_s);
QObject::connect(&engine, &QQmlApplicationEngine::objectCreated,

View File

@ -316,26 +316,44 @@ void ModelInfo::setRepeatPenaltyTokens(int t)
m_repeatPenaltyTokens = t;
}
QString ModelInfo::promptTemplate() const
QVariant ModelInfo::defaultChatTemplate() const
{
return MySettings::globalInstance()->modelPromptTemplate(*this);
auto res = m_chatTemplate.or_else([this] -> std::optional<QString> {
if (!installed || isOnline)
return std::nullopt;
if (!m_modelChatTemplate) {
auto path = (dirpath + filename()).toUtf8();
auto res = LLModel::Implementation::chatTemplate(path.constData());
if (res) {
m_modelChatTemplate = QString::fromStdString(*res);
} else {
qWarning().nospace() << "failed to get chat template for " << filename() << ": " << res.error().c_str();
m_modelChatTemplate = QString(); // do not retry
}
}
if (m_modelChatTemplate->isNull())
return std::nullopt;
return m_modelChatTemplate;
});
if (res)
return std::move(*res);
return QVariant::fromValue(nullptr);
}
void ModelInfo::setPromptTemplate(const QString &t)
auto ModelInfo::chatTemplate() const -> UpgradeableSetting
{
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelPromptTemplate(*this, t, true /*force*/);
m_promptTemplate = t;
return MySettings::globalInstance()->modelChatTemplate(*this);
}
QString ModelInfo::systemPrompt() const
QString ModelInfo::defaultSystemMessage() const
{
return MySettings::globalInstance()->modelSystemPrompt(*this);
return m_systemMessage;
}
void ModelInfo::setSystemPrompt(const QString &p)
auto ModelInfo::systemMessage() const -> UpgradeableSetting
{
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/);
m_systemPrompt = p;
return MySettings::globalInstance()->modelSystemMessage(*this);
}
QString ModelInfo::chatNamePrompt() const
@ -360,39 +378,41 @@ void ModelInfo::setSuggestedFollowUpPrompt(const QString &p)
m_suggestedFollowUpPrompt = p;
}
// FIXME(jared): this should not be used for model settings that have meaningful defaults, such as temperature
bool ModelInfo::shouldSaveMetadata() const
{
return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/);
}
QVariantMap ModelInfo::getFields() const
QVariant ModelInfo::getField(QLatin1StringView name) const
{
return {
{ "filename", m_filename },
{ "description", m_description },
{ "url", m_url },
{ "quant", m_quant },
{ "type", m_type },
{ "isClone", m_isClone },
{ "isDiscovered", m_isDiscovered },
{ "likes", m_likes },
{ "downloads", m_downloads },
{ "recency", m_recency },
{ "temperature", m_temperature },
{ "topP", m_topP },
{ "minP", m_minP },
{ "topK", m_topK },
{ "maxLength", m_maxLength },
{ "promptBatchSize", m_promptBatchSize },
{ "contextLength", m_contextLength },
{ "gpuLayers", m_gpuLayers },
{ "repeatPenalty", m_repeatPenalty },
{ "repeatPenaltyTokens", m_repeatPenaltyTokens },
{ "promptTemplate", m_promptTemplate },
{ "systemPrompt", m_systemPrompt },
{ "chatNamePrompt", m_chatNamePrompt },
{ "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt },
static const std::unordered_map<QLatin1StringView, QVariant(*)(const ModelInfo &)> s_fields = {
{ "filename"_L1, [](auto &i) -> QVariant { return i.m_filename; } },
{ "description"_L1, [](auto &i) -> QVariant { return i.m_description; } },
{ "url"_L1, [](auto &i) -> QVariant { return i.m_url; } },
{ "quant"_L1, [](auto &i) -> QVariant { return i.m_quant; } },
{ "type"_L1, [](auto &i) -> QVariant { return i.m_type; } },
{ "isClone"_L1, [](auto &i) -> QVariant { return i.m_isClone; } },
{ "isDiscovered"_L1, [](auto &i) -> QVariant { return i.m_isDiscovered; } },
{ "likes"_L1, [](auto &i) -> QVariant { return i.m_likes; } },
{ "downloads"_L1, [](auto &i) -> QVariant { return i.m_downloads; } },
{ "recency"_L1, [](auto &i) -> QVariant { return i.m_recency; } },
{ "temperature"_L1, [](auto &i) -> QVariant { return i.m_temperature; } },
{ "topP"_L1, [](auto &i) -> QVariant { return i.m_topP; } },
{ "minP"_L1, [](auto &i) -> QVariant { return i.m_minP; } },
{ "topK"_L1, [](auto &i) -> QVariant { return i.m_topK; } },
{ "maxLength"_L1, [](auto &i) -> QVariant { return i.m_maxLength; } },
{ "promptBatchSize"_L1, [](auto &i) -> QVariant { return i.m_promptBatchSize; } },
{ "contextLength"_L1, [](auto &i) -> QVariant { return i.m_contextLength; } },
{ "gpuLayers"_L1, [](auto &i) -> QVariant { return i.m_gpuLayers; } },
{ "repeatPenalty"_L1, [](auto &i) -> QVariant { return i.m_repeatPenalty; } },
{ "repeatPenaltyTokens"_L1, [](auto &i) -> QVariant { return i.m_repeatPenaltyTokens; } },
{ "chatTemplate"_L1, [](auto &i) -> QVariant { return i.defaultChatTemplate(); } },
{ "systemMessage"_L1, [](auto &i) -> QVariant { return i.m_systemMessage; } },
{ "chatNamePrompt"_L1, [](auto &i) -> QVariant { return i.m_chatNamePrompt; } },
{ "suggestedFollowUpPrompt"_L1, [](auto &i) -> QVariant { return i.m_suggestedFollowUpPrompt; } },
};
return s_fields.at(name)(*this);
}
InstalledModels::InstalledModels(QObject *parent, bool selectable)
@ -491,31 +511,48 @@ ModelList::ModelList()
m_selectableModels->setSourceModel(this);
m_downloadableModels->setSourceModel(this);
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson);
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings);
connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings);
auto *mySettings = MySettings::globalInstance();
connect(mySettings, &MySettings::nameChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::topPChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::minPChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::topKChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings );
connect(mySettings, &MySettings::chatTemplateChanged, this, &ModelList::maybeUpdateDataForSettings);
connect(mySettings, &MySettings::systemMessageChanged, this, &ModelList::maybeUpdateDataForSettings);
connect(this, &ModelList::dataChanged, this, &ModelList::onDataChanged);
connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors);
updateModelsFromJson();
updateModelsFromSettings();
updateModelsFromDirectory();
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson );
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings );
QCoreApplication::instance()->installEventFilter(this);
}
// an easier way to listen for model info and setting changes
void ModelList::onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
{
Q_UNUSED(roles)
for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
auto index = topLeft.siblingAtRow(row);
auto id = index.data(ModelList::IdRole).toString();
if (auto info = modelInfo(id); !info.id().isNull())
emit modelInfoChanged(info);
}
}
QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) {
QCryptographicHash sha256(QCryptographicHash::Sha256);
sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8());
@ -776,10 +813,10 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->repeatPenalty();
case RepeatPenaltyTokensRole:
return info->repeatPenaltyTokens();
case PromptTemplateRole:
return info->promptTemplate();
case SystemPromptRole:
return info->systemPrompt();
case ChatTemplateRole:
return QVariant::fromValue(info->chatTemplate());
case SystemMessageRole:
return QVariant::fromValue(info->systemMessage());
case ChatNamePromptRole:
return info->chatNamePrompt();
case SuggestedFollowUpPromptRole:
@ -952,10 +989,10 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
info->setRepeatPenalty(value.toDouble()); break;
case RepeatPenaltyTokensRole:
info->setRepeatPenaltyTokens(value.toInt()); break;
case PromptTemplateRole:
info->setPromptTemplate(value.toString()); break;
case SystemPromptRole:
info->setSystemPrompt(value.toString()); break;
case ChatTemplateRole:
info->m_chatTemplate = value.toString(); break;
case SystemMessageRole:
info->m_systemMessage = value.toString(); break;
case ChatNamePromptRole:
info->setChatNamePrompt(value.toString()); break;
case SuggestedFollowUpPromptRole:
@ -1056,11 +1093,11 @@ ModelInfo ModelList::modelInfo(const QString &id) const
return *m_modelMap.value(id);
}
ModelInfo ModelList::modelInfoByFilename(const QString &filename) const
ModelInfo ModelList::modelInfoByFilename(const QString &filename, bool allowClone) const
{
QMutexLocker locker(&m_mutex);
for (ModelInfo *info : m_models)
if (info->filename() == filename)
if (info->filename() == filename && (allowClone || !info->isClone()))
return *info;
return ModelInfo();
}
@ -1080,6 +1117,20 @@ QString ModelList::clone(const ModelInfo &model)
const QString id = Network::globalInstance()->generateUniqueId();
addModel(id);
QString chatTemplate, systemMessage;
if (auto tmpl = model.chatTemplate().asModern()) {
chatTemplate = *tmpl;
} else {
qWarning("ModelList Warning: attempted to clone model with legacy chat template");
return {};
}
if (auto msg = model.systemMessage().asModern()) {
systemMessage = *msg;
} else {
qWarning("ModelList Warning: attempted to clone model with legacy system message");
return {};
}
QVector<QPair<int, QVariant>> data {
{ ModelList::InstalledRole, model.installed },
{ ModelList::IsCloneRole, true },
@ -1099,8 +1150,8 @@ QString ModelList::clone(const ModelInfo &model)
{ ModelList::GpuLayersRole, model.gpuLayers() },
{ ModelList::RepeatPenaltyRole, model.repeatPenalty() },
{ ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
{ ModelList::PromptTemplateRole, model.promptTemplate() },
{ ModelList::SystemPromptRole, model.systemPrompt() },
{ ModelList::ChatTemplateRole, chatTemplate },
{ ModelList::SystemMessageRole, systemMessage },
{ ModelList::ChatNamePromptRole, model.chatNamePrompt() },
{ ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
};
@ -1125,21 +1176,23 @@ void ModelList::removeInstalled(const ModelInfo &model)
removeInternal(model);
}
int ModelList::indexByModelId(const QString &id) const
{
QMutexLocker locker(&m_mutex);
if (auto it = m_modelMap.find(id); it != m_modelMap.cend())
return m_models.indexOf(*it);
return -1;
}
void ModelList::removeInternal(const ModelInfo &model)
{
const bool hasModel = contains(model.id());
Q_ASSERT(hasModel);
if (!hasModel) {
int indexOfModel = indexByModelId(model.id());
Q_ASSERT(indexOfModel != -1);
if (indexOfModel == -1) {
qWarning() << "ERROR: model list does not contain" << model.id();
return;
}
int indexOfModel = 0;
{
QMutexLocker locker(&m_mutex);
ModelInfo *info = m_modelMap.value(model.id());
indexOfModel = m_models.indexOf(info);
}
beginRemoveRows(QModelIndex(), indexOfModel, indexOfModel);
{
QMutexLocker locker(&m_mutex);
@ -1314,8 +1367,6 @@ void ModelList::processModelDirectory(const QString &path)
// The description is hard-coded into "GPT4All.ini" due to a performance issue.
// If the description were loaded dynamically from the .rmodel file, the ModelList would incur high I/O usage.
data.append({ DescriptionRole, description });
// The prompt template should be left clear when using the ChatML format, which most OpenAI-compatible API servers use.
data.append({ PromptTemplateRole, "%1" });
}
updateData(id, data);
}
@ -1451,9 +1502,20 @@ void ModelList::handleSslErrors(QNetworkReply *reply, const QList<QSslError> &er
qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url;
}
void ModelList::maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo)
{
// ignore updates that were *because* of a dataChanged - would cause a circular dependency
int idx;
if (!fromInfo && (idx = indexByModelId(info.id())) != -1) {
emit dataChanged(index(idx, 0), index(idx, 0));
emit selectableModelListChanged();
}
}
void ModelList::updateDataForSettings()
{
emit dataChanged(index(0, 0), index(m_models.size() - 1, 0));
emit selectableModelListChanged();
}
void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
@ -1560,10 +1622,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
data.append({ ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble() });
if (obj.contains("repeatPenaltyTokens"))
data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() });
if (obj.contains("promptTemplate"))
data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() });
if (obj.contains("systemPrompt"))
data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() });
if (auto it = obj.find("chatTemplate"_L1); it != obj.end())
data.append({ ModelList::ChatTemplateRole, it->toString() });
if (auto it = obj.find("systemMessage"_L1); it != obj.end())
data.append({ ModelList::SystemMessageRole, it->toString() });
updateData(id, data);
}
@ -1755,6 +1817,9 @@ void ModelList::updateDiscoveredInstalled(const ModelInfo &info)
updateData(info.id(), data);
}
// FIXME(jared): This should only contain fields without reasonable defaults such as name, description, and URL.
// For other settings, there is no authoritative value and we should load the setting lazily like we do
// for any other override.
void ModelList::updateModelsFromSettings()
{
QSettings settings;
@ -1769,12 +1834,27 @@ void ModelList::updateModelsFromSettings()
// If we can't find the corresponding file, then ignore it as this reflects a stale model.
// The file could have been deleted manually by the user for instance or temporarily renamed.
if (!settings.contains(g + "/filename") || !modelExists(settings.value(g + "/filename").toString()))
continue;
QString filename;
{
auto value = settings.value(u"%1/filename"_s.arg(g));
if (!value.isValid() || !modelExists(filename = value.toString()))
continue;
}
QVector<QPair<int, QVariant>> data;
// load data from base model
// FIXME(jared): how does "Restore Defaults" work for other settings of clones which we don't do this for?
if (auto base = modelInfoByFilename(filename, /*allowClone*/ false); !base.id().isNull()) {
if (auto tmpl = base.m_chatTemplate)
data.append({ ModelList::ChatTemplateRole, *tmpl });
if (auto msg = base.m_systemMessage; !msg.isNull())
data.append({ ModelList::SystemMessageRole, msg });
}
addModel(id);
QVector<QPair<int, QVariant>> data;
// load data from settings
if (settings.contains(g + "/name")) {
const QString name = settings.value(g + "/name").toString();
data.append({ ModelList::NameRole, name });
@ -1859,14 +1939,6 @@ void ModelList::updateModelsFromSettings()
const int repeatPenaltyTokens = settings.value(g + "/repeatPenaltyTokens").toInt();
data.append({ ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens });
}
if (settings.contains(g + "/promptTemplate")) {
const QString promptTemplate = settings.value(g + "/promptTemplate").toString();
data.append({ ModelList::PromptTemplateRole, promptTemplate });
}
if (settings.contains(g + "/systemPrompt")) {
const QString systemPrompt = settings.value(g + "/systemPrompt").toString();
data.append({ ModelList::SystemPromptRole, systemPrompt });
}
if (settings.contains(g + "/chatNamePrompt")) {
const QString chatNamePrompt = settings.value(g + "/chatNamePrompt").toString();
data.append({ ModelList::ChatNamePromptRole, chatNamePrompt });

View File

@ -5,12 +5,14 @@
#include <QByteArray>
#include <QDateTime>
#include <QHash>
#include <QLatin1StringView>
#include <QList>
#include <QMutex>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QObject>
#include <QPair>
#include <QQmlEngine>
#include <QSortFilterProxyModel>
#include <QSslError>
#include <QString>
@ -19,11 +21,53 @@
#include <Qt>
#include <QtGlobal>
#include <optional>
#include <utility>
using namespace Qt::Literals::StringLiterals;
class UpgradeableSetting {
Q_GADGET
QML_ANONYMOUS
// NOTE: Unset implies there is neither a value nor a default
enum class State { Unset, Legacy, Modern };
Q_PROPERTY(bool isSet READ isSet )
Q_PROPERTY(bool isLegacy READ isLegacy)
Q_PROPERTY(bool isModern READ isModern)
Q_PROPERTY(QVariant value READ value) // string or null
public:
struct legacy_tag_t { explicit legacy_tag_t() = default; };
static inline constexpr legacy_tag_t legacy_tag = legacy_tag_t();
UpgradeableSetting() : m_state(State::Unset ) {}
UpgradeableSetting(legacy_tag_t, QString value): m_state(State::Legacy), m_value(std::move(value)) {}
UpgradeableSetting( QString value): m_state(State::Modern), m_value(std::move(value)) {}
bool isSet () const { return m_state != State::Unset; }
bool isLegacy() const { return m_state == State::Legacy; }
bool isModern() const { return m_state == State::Modern; }
QVariant value () const { return m_state == State::Unset ? QVariant::fromValue(nullptr) : m_value; }
friend bool operator==(const UpgradeableSetting &a, const UpgradeableSetting &b)
{ return a.m_state == b.m_state && (a.m_state == State::Unset || a.m_value == b.m_value); }
// returns std::nullopt if there is a legacy template or it is not set
std::optional<QString> asModern() const
{
if (m_state == State::Modern)
return m_value;
return std::nullopt;
}
private:
State m_state;
QString m_value;
};
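A hedged sketch (illustration only, not from this commit; the header name is assumed) of the three states and how asModern() behaves for each:
#include "modellist.h"
#include <cassert>
static void upgradeableSettingExample()
{
    UpgradeableSetting unset;                                           // neither value nor default
    UpgradeableSetting legacy(UpgradeableSetting::legacy_tag, u"%1"_s); // old %1-style prompt template
    UpgradeableSetting modern(u"{{ messages }}"_s);                     // any new Jinja-style template
    assert(!unset.isSet() && legacy.isLegacy() && modern.isModern());
    assert(!legacy.asModern());                      // legacy values are never handed out as modern
    assert(*modern.asModern() == u"{{ messages }}"_s);
    // unset.value() returns a QVariant holding nullptr, which QML sees as null.
}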
struct ModelInfo {
Q_GADGET
Q_PROPERTY(QString id READ id WRITE setId)
@ -69,8 +113,11 @@ struct ModelInfo {
Q_PROPERTY(int maxGpuLayers READ maxGpuLayers)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt)
// user-defined chat template and system message must be written through settings because of their legacy compat
Q_PROPERTY(QVariant defaultChatTemplate READ defaultChatTemplate )
Q_PROPERTY(UpgradeableSetting chatTemplate READ chatTemplate )
Q_PROPERTY(QString defaultSystemMessage READ defaultSystemMessage)
Q_PROPERTY(UpgradeableSetting systemMessage READ systemMessage )
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
Q_PROPERTY(int likes READ likes WRITE setLikes)
@ -178,19 +225,22 @@ public:
void setRepeatPenalty(double p);
int repeatPenaltyTokens() const;
void setRepeatPenaltyTokens(int t);
QString promptTemplate() const;
void setPromptTemplate(const QString &t);
QString systemPrompt() const;
void setSystemPrompt(const QString &p);
QVariant defaultChatTemplate() const;
UpgradeableSetting chatTemplate() const;
QString defaultSystemMessage() const;
UpgradeableSetting systemMessage() const;
QString chatNamePrompt() const;
void setChatNamePrompt(const QString &p);
QString suggestedFollowUpPrompt() const;
void setSuggestedFollowUpPrompt(const QString &p);
// Some metadata must be saved to settings because it does not have a meaningful default from some other source.
// This is useful for fields such as name, description, and URL.
// It is true for any models that have not been installed from models.json.
bool shouldSaveMetadata() const;
private:
QVariantMap getFields() const;
QVariant getField(QLatin1StringView name) const;
QString m_id;
QString m_name;
@ -216,11 +266,13 @@ private:
mutable int m_maxGpuLayers = -1;
double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n";
QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
std::optional<QString> m_chatTemplate;
mutable std::optional<QString> m_modelChatTemplate;
QString m_systemMessage;
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
friend class MySettings;
friend class ModelList;
};
Q_DECLARE_METATYPE(ModelInfo)
@ -340,8 +392,8 @@ public:
GpuLayersRole,
RepeatPenaltyRole,
RepeatPenaltyTokensRole,
PromptTemplateRole,
SystemPromptRole,
ChatTemplateRole,
SystemMessageRole,
ChatNamePromptRole,
SuggestedFollowUpPromptRole,
MinPRole,
@ -394,8 +446,8 @@ public:
roles[GpuLayersRole] = "gpuLayers";
roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate";
roles[SystemPromptRole] = "systemPrompt";
roles[ChatTemplateRole] = "chatTemplate";
roles[SystemMessageRole] = "systemMessage";
roles[ChatNamePromptRole] = "chatNamePrompt";
roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";
roles[LikesRole] = "likes";
@ -416,7 +468,7 @@ public:
bool contains(const QString &id) const;
bool containsByFilename(const QString &filename) const;
Q_INVOKABLE ModelInfo modelInfo(const QString &id) const;
Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename) const;
Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename, bool allowClone = true) const;
Q_INVOKABLE bool isUniqueName(const QString &name) const;
Q_INVOKABLE QString clone(const ModelInfo &model);
Q_INVOKABLE void removeClone(const ModelInfo &model);
@ -476,15 +528,18 @@ Q_SIGNALS:
void discoverSortChanged();
void discoverProgressChanged();
void discoverInProgressChanged();
void modelInfoChanged(const ModelInfo &info);
protected:
bool eventFilter(QObject *obj, QEvent *ev) override;
private Q_SLOTS:
void onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles);
void resortModel();
void updateModelsFromJson();
void updateModelsFromJsonAsync();
void updateModelsFromSettings();
void maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo);
void updateDataForSettings();
void handleModelsJsonDownloadFinished();
void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code);
@ -495,6 +550,9 @@ private Q_SLOTS:
void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors);
private:
// Return the index of the model with the given id, or -1 if not found.
int indexByModelId(const QString &id) const;
void removeInternal(const ModelInfo &model);
void clearDiscoveredModels();
bool modelExists(const QString &fileName) const;

View File

@ -1,5 +1,8 @@
#include "mysettings.h"
#include "chatllm.h"
#include "modellist.h"
#include <gpt4all-backend/llmodel.h>
#include <QDebug>
@ -29,8 +32,13 @@ static const QStringList suggestionModeNames { "LocalDocsOnly", "On", "Off" };
static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" };
static const QStringList fontSizeNames { "Small", "Medium", "Large" };
// FIXME: All of these default strings that are shown in the UI for settings need to be marked as
// translatable
// pseudo-enum
namespace ModelSettingsKey { namespace {
auto ChatTemplate = "chatTemplate"_L1;
auto PromptTemplate = "promptTemplate"_L1; // legacy
auto SystemMessage = "systemMessage"_L1;
auto SystemPrompt = "systemPrompt"_L1; // legacy
} } // namespace ModelSettingsKey::(anonymous)
namespace defaults {
@ -48,7 +56,6 @@ static const QVariantMap basicDefaults {
{ "fontSize", QVariant::fromValue(FontSize::Small) },
{ "lastVersionStarted", "" },
{ "networkPort", 4891, },
{ "saveChatsContext", false },
{ "systemTray", false },
{ "serverChat", false },
{ "userDefaultModel", "Application default" },
@ -147,6 +154,11 @@ static QStringList getUiLanguages(const QString &modelPath)
return languageList;
}
static QString modelSettingName(const ModelInfo &info, auto &&name)
{
return u"model-%1/%2"_s.arg(info.id(), name);
}
class MyPrivateSettings: public MySettings { };
Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance)
MySettings *MySettings::globalInstance()
@ -162,6 +174,34 @@ MySettings::MySettings()
{
}
QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
{
if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
return QString::fromStdString(*err);
return QVariant::fromValue(nullptr);
}
// Unset settings come from ModelInfo. Listen for changes so we can emit our own setting-specific signals.
void MySettings::onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
{
auto settingChanged = [&](const auto &info, auto role, const auto &name) {
return (roles.isEmpty() || roles.contains(role)) && !m_settings.contains(modelSettingName(info, name));
};
auto &modelList = dynamic_cast<const ModelList &>(*QObject::sender());
for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
using enum ModelList::Roles;
using namespace ModelSettingsKey;
auto index = topLeft.siblingAtRow(row);
if (auto info = modelList.modelInfo(index.data(IdRole).toString()); !info.id().isNull()) {
if (settingChanged(info, ChatTemplateRole, ChatTemplate))
emit chatTemplateChanged(info, /*fromInfo*/ true);
if (settingChanged(info, SystemMessageRole, SystemMessage))
emit systemMessageChanged(info, /*fromInfo*/ true);
}
}
}
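The settingChanged lambda above treats a model setting as inherited whenever no per-model key exists in QSettings. A hedged sketch (key naming follows modelSettingName() above; the caller is assumed) of checking for such an override:
#include <QSettings>
#include <QString>
// Illustrative only: an explicit override lives under "model-<id>/chatTemplate";
// without it, the value falls through to the ModelInfo default.
static bool hasChatTemplateOverride(const QString &modelId)
{
    QSettings settings;
    return settings.contains(QStringLiteral("model-%1/chatTemplate").arg(modelId));
}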
QVariant MySettings::getBasicSetting(const QString &name) const
{
return m_settings.value(name, basicDefaults.value(name));
@ -194,8 +234,8 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
setModelGpuLayers(info, info.m_gpuLayers);
setModelRepeatPenalty(info, info.m_repeatPenalty);
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
setModelPromptTemplate(info, info.m_promptTemplate);
setModelSystemPrompt(info, info.m_systemPrompt);
resetModelChatTemplate (info);
resetModelSystemMessage(info);
setModelChatNamePrompt(info, info.m_chatNamePrompt);
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
}
@ -206,7 +246,6 @@ void MySettings::restoreApplicationDefaults()
setFontSize(basicDefaults.value("fontSize").value<FontSize>());
setDevice(defaults::device);
setThreadCount(defaults::threadCount);
setSaveChatsContext(basicDefaults.value("saveChatsContext").toBool());
setSystemTray(basicDefaults.value("systemTray").toBool());
setServerChat(basicDefaults.value("serverChat").toBool());
setNetworkPort(basicDefaults.value("networkPort").toInt());
@ -252,29 +291,37 @@ void MySettings::setModelName(const ModelInfo &info, const QString &value, bool
emit nameChanged(info);
}
static QString modelSettingName(const ModelInfo &info, const QString &name)
QVariant MySettings::getModelSetting(QLatin1StringView name, const ModelInfo &info) const
{
return u"model-%1/%2"_s.arg(info.id(), name);
QLatin1StringView nameL1(name);
return m_settings.value(modelSettingName(info, nameL1), info.getField(nameL1));
}
QVariant MySettings::getModelSetting(const QString &name, const ModelInfo &info) const
QVariant MySettings::getModelSetting(const char *name, const ModelInfo &info) const
{
return m_settings.value(modelSettingName(info, name), info.getFields().value(name));
return getModelSetting(QLatin1StringView(name), info);
}
void MySettings::setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force,
void MySettings::setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
bool signal)
{
if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value))
return;
QString settingName = modelSettingName(info, name);
if (info.getFields().value(name) == value && !info.shouldSaveMetadata())
QLatin1StringView nameL1(name);
QString settingName = modelSettingName(info, nameL1);
if (info.getField(nameL1) == value && !info.shouldSaveMetadata())
m_settings.remove(settingName);
else
m_settings.setValue(settingName, value);
if (signal && !force)
QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(name).toLatin1().constData(), Q_ARG(ModelInfo, info));
QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(nameL1).toLatin1().constData(), Q_ARG(ModelInfo, info));
}
void MySettings::setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
bool signal)
{
setModelSetting(QLatin1StringView(name), info, value, force, signal);
}
QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); }
@ -297,11 +344,68 @@ int MySettings::modelContextLength (const ModelInfo &info) const
int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); }
double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
auto MySettings::getUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const -> UpgradeableSetting
{
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return {};
}
auto value = m_settings.value(modelSettingName(info, legacyKey));
if (value.isValid())
return { UpgradeableSetting::legacy_tag, value.toString() };
value = getModelSetting(newKey, info);
if (!value.isNull())
return value.toString();
return {}; // neither a default nor an override
}
bool MySettings::isUpgradeableModelSettingSet(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const
{
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
if (m_settings.contains(modelSettingName(info, legacyKey)))
return true;
    // NOTE: unlike getUpgradeableModelSetting(), this ignores the default
return m_settings.contains(modelSettingName(info, newKey));
}
auto MySettings::modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting
{
using namespace ModelSettingsKey;
return getUpgradeableModelSetting(info, PromptTemplate, ChatTemplate);
}
bool MySettings::isModelChatTemplateSet(const ModelInfo &info) const
{
using namespace ModelSettingsKey;
return isUpgradeableModelSettingSet(info, PromptTemplate, ChatTemplate);
}
auto MySettings::modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting
{
using namespace ModelSettingsKey;
return getUpgradeableModelSetting(info, SystemPrompt, SystemMessage);
}
bool MySettings::isModelSystemMessageSet(const ModelInfo &info) const
{
using namespace ModelSettingsKey;
return isUpgradeableModelSettingSet(info, SystemPrompt, SystemMessage);
}
void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force)
{
setModelSetting("filename", info, value, force, true);
@ -402,14 +506,77 @@ void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &info, int value, b
setModelSetting("repeatPenaltyTokens", info, value, force, true);
}
void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force)
{
setModelSetting("promptTemplate", info, value, force, true);
bool MySettings::setUpgradeableModelSetting(
const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
) {
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
auto legacyModelKey = modelSettingName(info, legacyKey);
auto newModelKey = modelSettingName(info, newKey );
bool changed = false;
if (m_settings.contains(legacyModelKey)) {
m_settings.remove(legacyModelKey);
changed = true;
}
auto oldValue = m_settings.value(newModelKey);
if (!oldValue.isValid() || oldValue.toString() != value) {
m_settings.setValue(newModelKey, value);
changed = true;
}
return changed;
}
void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force)
bool MySettings::resetUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) {
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
auto legacyModelKey = modelSettingName(info, legacyKey);
auto newModelKey = modelSettingName(info, newKey );
bool changed = false;
if (m_settings.contains(legacyModelKey)) {
m_settings.remove(legacyModelKey);
changed = true;
}
if (m_settings.contains(newModelKey)) {
m_settings.remove(newModelKey);
changed = true;
}
return changed;
}
void MySettings::setModelChatTemplate(const ModelInfo &info, const QString &value)
{
setModelSetting("systemPrompt", info, value, force, true);
using namespace ModelSettingsKey;
if (setUpgradeableModelSetting(info, value, PromptTemplate, ChatTemplate))
emit chatTemplateChanged(info);
}
void MySettings::resetModelChatTemplate(const ModelInfo &info)
{
using namespace ModelSettingsKey;
if (resetUpgradeableModelSetting(info, PromptTemplate, ChatTemplate))
emit chatTemplateChanged(info);
}
void MySettings::setModelSystemMessage(const ModelInfo &info, const QString &value)
{
using namespace ModelSettingsKey;
if (setUpgradeableModelSetting(info, value, SystemPrompt, SystemMessage))
emit systemMessageChanged(info);
}
void MySettings::resetModelSystemMessage(const ModelInfo &info)
{
using namespace ModelSettingsKey;
if (resetUpgradeableModelSetting(info, SystemPrompt, SystemMessage))
emit systemMessageChanged(info);
}
void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force)
@ -445,7 +612,6 @@ void MySettings::setThreadCount(int value)
emit threadCountChanged();
}
bool MySettings::saveChatsContext() const { return getBasicSetting("saveChatsContext" ).toBool(); }
bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); }
bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); }
int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); }
@ -464,7 +630,6 @@ ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnu
FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); }
SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); }
void MySettings::setSaveChatsContext(bool value) { setBasicSetting("saveChatsContext", value); }
void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); }
void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); }
void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); }
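
A minimal caller-side sketch (not part of this commit) of the upgradeable-setting API above: reading the chat template for a model. The helper name effectiveChatTemplate is hypothetical, and UpgradeableSetting::asModern() is assumed to return an optional-like value, matching how network.cpp dereferences it below.

#include "modellist.h"
#include "mysettings.h"

#include <QString>

QString effectiveChatTemplate(const ModelInfo &info)
{
    auto *settings = MySettings::globalInstance();
    // An explicit override or built-in Jinja template yields a modern value;
    // a leftover legacy "promptTemplate" override or an unset model does not.
    if (auto tmpl = settings->modelChatTemplate(info).asModern())
        return *tmpl;
    return QString();
}

A legacy override is only cleared once the user upgrades it, e.g. via setModelChatTemplate(info, value) or resetModelChatTemplate(info), both of which remove the old "promptTemplate" key.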

View File

@ -4,6 +4,9 @@
#include "modellist.h" // IWYU pragma: keep
#include <QDateTime>
#include <QLatin1StringView>
#include <QList>
#include <QModelIndex>
#include <QObject>
#include <QSettings>
#include <QString>
@ -48,7 +51,6 @@ class MySettings : public QObject
{
Q_OBJECT
Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool saveChatsContext READ saveChatsContext WRITE setSaveChatsContext NOTIFY saveChatsContextChanged)
Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged)
Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged)
Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged)
@ -75,9 +77,18 @@ class MySettings : public QObject
Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
private:
explicit MySettings();
~MySettings() override = default;
public Q_SLOTS:
void onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles = {});
public:
static MySettings *globalInstance();
Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
// Restore methods
Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info);
Q_INVOKABLE void restoreApplicationDefaults();
@ -125,10 +136,14 @@ public:
Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false);
int modelRepeatPenaltyTokens(const ModelInfo &info) const;
Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
QString modelPromptTemplate(const ModelInfo &info) const;
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
QString modelSystemPrompt(const ModelInfo &info) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false);
auto modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting;
Q_INVOKABLE bool isModelChatTemplateSet(const ModelInfo &info) const;
Q_INVOKABLE void setModelChatTemplate(const ModelInfo &info, const QString &value);
Q_INVOKABLE void resetModelChatTemplate(const ModelInfo &info);
auto modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting;
Q_INVOKABLE bool isModelSystemMessageSet(const ModelInfo &info) const;
Q_INVOKABLE void setModelSystemMessage(const ModelInfo &info, const QString &value);
Q_INVOKABLE void resetModelSystemMessage(const ModelInfo &info);
int modelContextLength(const ModelInfo &info) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false);
int modelGpuLayers(const ModelInfo &info) const;
@ -141,8 +156,6 @@ public:
// Application settings
int threadCount() const;
void setThreadCount(int value);
bool saveChatsContext() const;
void setSaveChatsContext(bool value);
bool systemTray() const;
void setSystemTray(bool value);
bool serverChat() const;
@ -215,12 +228,11 @@ Q_SIGNALS:
void gpuLayersChanged(const ModelInfo &info);
void repeatPenaltyChanged(const ModelInfo &info);
void repeatPenaltyTokensChanged(const ModelInfo &info);
void promptTemplateChanged(const ModelInfo &info);
void systemPromptChanged(const ModelInfo &info);
void chatTemplateChanged(const ModelInfo &info, bool fromInfo = false);
void systemMessageChanged(const ModelInfo &info, bool fromInfo = false);
void chatNamePromptChanged(const ModelInfo &info);
void suggestedFollowUpPromptChanged(const ModelInfo &info);
void threadCountChanged();
void saveChatsContextChanged();
void systemTrayChanged();
void serverChatChanged();
void modelPathChanged();
@ -245,6 +257,30 @@ Q_SIGNALS:
void suggestionModeChanged();
void languageAndLocaleChanged();
private:
QVariant getBasicSetting(const QString &name) const;
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
QVariant getModelSetting(QLatin1StringView name, const ModelInfo &info) const;
QVariant getModelSetting(const char *name, const ModelInfo &info) const;
void setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
void setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
auto getUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const -> UpgradeableSetting;
bool isUpgradeableModelSettingSet(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const;
bool setUpgradeableModelSetting(
const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
);
bool resetUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
);
QString filePathForLocale(const QLocale &locale);
private:
QSettings m_settings;
bool m_forceMetal;
@ -253,18 +289,7 @@ private:
const QStringList m_uiLanguages;
std::unique_ptr<QTranslator> m_translator;
private:
explicit MySettings();
~MySettings() {}
friend class MyPrivateSettings;
QVariant getBasicSetting(const QString &name) const;
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
QVariant getModelSetting(const QString &name, const ModelInfo &info) const;
void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
QString filePathForLocale(const QLocale &locale);
};
#endif // MYSETTINGS_H

View File

@ -8,6 +8,7 @@
#include "localdocsmodel.h"
#include "modellist.h"
#include "mysettings.h"
#include "utils.h"
#include <gpt4all-backend/llmodel.h>
@ -192,11 +193,14 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
return false;
}
auto *currentChat = ChatListModel::globalInstance()->currentChat();
Q_ASSERT(currentChat);
auto modelInfo = currentChat->modelInfo();
Q_ASSERT(doc.isObject());
Q_ASSERT(ChatListModel::globalInstance()->currentChat());
QJsonObject object = doc.object();
object.insert("source", "gpt4all-chat");
object.insert("agent_id", ChatListModel::globalInstance()->currentChat()->modelInfo().filename());
object.insert("agent_id", modelInfo.filename());
object.insert("submitter_id", m_uniqueId);
object.insert("ingest_id", ingestId);
@ -204,8 +208,9 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
if (!attribution.isEmpty())
object.insert("network/attribution", attribution);
QString promptTemplate = ChatListModel::globalInstance()->currentChat()->modelInfo().promptTemplate();
object.insert("prompt_template", promptTemplate);
if (!modelInfo.id().isNull())
if (auto tmpl = modelInfo.chatTemplate().asModern())
object.insert("chat_template"_L1, *tmpl);
QJsonDocument newDoc;
newDoc.setObject(object);
@ -358,7 +363,8 @@ void Network::sendStartup()
void Network::trackChatEvent(const QString &ev, QVariantMap props)
{
const auto &curChat = ChatListModel::globalInstance()->currentChat();
auto *curChat = ChatListModel::globalInstance()->currentChat();
Q_ASSERT(curChat);
if (!props.contains("model"))
props.insert("model", curChat->modelInfo().filename());
props.insert("device_backend", curChat->deviceBackend());
@ -366,7 +372,7 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
props.insert("doc_collections_enabled", curChat->collectionList().count());
props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
props.insert("using_server", ChatListModel::globalInstance()->currentChat()->isServer());
props.insert("using_server", curChat->isServer());
trackEvent(ev, props);
}

View File

@ -313,11 +313,8 @@ const std::unordered_map<BaseCompletionRequest::Type, const char *> BaseCompleti
class ChatRequest : public BaseCompletionRequest {
public:
struct Message {
enum class Role : uint8_t {
User,
Assistant,
};
Role role;
enum class Role { System, User, Assistant };
Role role;
QString content;
};
@ -349,7 +346,6 @@ protected:
this->messages.clear();
{
QCborArray arr = value.toArray();
Message::Role nextRole = Message::Role::User;
for (qsizetype i = 0; i < arr.size(); i++) {
const auto &elem = arr[i];
if (!elem.isMap())
@ -360,9 +356,9 @@ protected:
QCborMap msg = elem.toMap();
Message res;
QString role = takeValue(msg, "role", String, /*required*/ true).toString();
if (role == u"system"_s)
continue; // FIXME(jared): don't ignore these
if (role == u"user"_s) {
if (role == u"system"_s) {
res.role = Message::Role::System;
} else if (role == u"user"_s) {
res.role = Message::Role::User;
} else if (role == u"assistant"_s) {
res.role = Message::Role::Assistant;
@ -374,13 +370,7 @@ protected:
));
}
res.content = takeValue(msg, "content", String, /*required*/ true).toString();
if (res.role != nextRole)
throw InvalidRequestError(fmt::format(
"Invalid 'messages[{}].role': did not expect '{}' here", i, role
));
this->messages.append(res);
nextRole = res.role == Message::Role::User ? Message::Role::Assistant
: Message::Role::User;
if (!msg.isEmpty())
throw InvalidRequestError(fmt::format(
@ -630,8 +620,7 @@ void Server::start()
});
#endif
connect(this, &Server::requestServerNewPromptResponsePair, m_chat,
&Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection);
connect(this, &Server::requestResetResponseState, m_chat, &Chat::resetResponseState, Qt::BlockingQueuedConnection);
}
static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
@ -642,6 +631,10 @@ static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::opt
auto Server::handleCompletionRequest(const CompletionRequest &request)
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{
Q_ASSERT(m_chatModel);
auto *mySettings = MySettings::globalInstance();
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
for (const ModelInfo &info : modelList) {
@ -654,10 +647,6 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
}
}
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(request.prompt); // blocks
resetResponse();
// load the new model if necessary
setShouldBeLoaded(true);
@ -666,47 +655,55 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
emit requestResetResponseState(); // blocks
qsizetype prevMsgIndex = m_chatModel->count() - 1;
if (prevMsgIndex >= 0)
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
// NB: this resets the context, regardless of whether this model is already loaded
if (!loadModel(modelInfo)) {
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
const int top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
// add prompt/response items to GUI
m_chatModel->appendPrompt(request.prompt);
m_chatModel->appendResponse(prevMsgIndex + 1);
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
.min_p = request.min_p,
.temp = request.temperature,
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
};
auto promptUtf8 = request.prompt.toUtf8();
int promptTokens = 0;
int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses;
QStringList responses;
for (int i = 0; i < request.n; ++i) {
if (!promptInternal(
m_collections,
request.prompt,
/*promptTemplate*/ u"%1"_s,
request.max_tokens,
top_k,
request.top_p,
request.min_p,
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n)) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
PromptResult result;
try {
result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
promptCtx,
/*usedLocalDocs*/ false);
} catch (const std::exception &e) {
emit responseChanged(e.what());
emit responseStopped(0);
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
QString resp = response(/*trim*/ false);
QString resp = QString::fromUtf8(result.response);
if (request.echo)
resp = request.prompt + resp;
responses.append({resp, m_databaseResults});
if (!promptTokens)
promptTokens = m_promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens;
if (i < request.n - 1)
resetResponse();
responses << resp;
if (i == 0)
promptTokens = result.promptTokens;
responseTokens += result.responseTokens;
}
QJsonObject responseObject {
@ -717,25 +714,13 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
};
QJsonArray choices;
{
int index = 0;
for (const auto &r : responses) {
QString result = r.first;
QList<ResultInfo> infos = r.second;
QJsonObject choice {
{ "text", result },
{ "index", index++ },
{ "logprobs", QJsonValue::Null },
{ "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
};
if (MySettings::globalInstance()->localDocsShowReferences()) {
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
}
choices.append(choice);
}
for (qsizetype i = 0; auto &resp : std::as_const(responses)) {
choices << QJsonObject {
{ "text", resp },
{ "index", i++ },
{ "logprobs", QJsonValue::Null },
{ "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
};
}
responseObject.insert("choices", choices);
@ -751,6 +736,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
auto Server::handleChatRequest(const ChatRequest &request)
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{
auto *mySettings = MySettings::globalInstance();
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
for (const ModelInfo &info : modelList) {
@ -771,83 +758,58 @@ auto Server::handleChatRequest(const ChatRequest &request)
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
emit requestResetResponseState(); // blocks
// NB: this resets the context, regardless of whether this model is already loaded
if (!loadModel(modelInfo)) {
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
const QString promptTemplate = modelInfo.promptTemplate();
const int top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
int promptTokens = 0;
Q_ASSERT(!request.messages.isEmpty());
// adds prompt/response items to GUI
std::vector<ChatItem> chatItems;
for (auto &message : request.messages) {
using enum ChatRequest::Message::Role;
switch (message.role) {
case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break;
case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break;
case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break;
}
}
m_chatModel->appendResponseWithHistory(chatItems);
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
.min_p = request.min_p,
.temp = request.temperature,
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
};
int promptTokens = 0;
int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses;
Q_ASSERT(!request.messages.isEmpty());
Q_ASSERT(request.messages.size() % 2 == 1);
for (int i = 0; i < request.messages.size() - 2; i += 2) {
using enum ChatRequest::Message::Role;
auto &user = request.messages[i];
auto &assistant = request.messages[i + 1];
Q_ASSERT(user.role == User);
Q_ASSERT(assistant.role == Assistant);
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(user.content); // blocks
resetResponse();
if (!promptInternal(
{},
user.content,
promptTemplate,
request.max_tokens,
top_k,
request.top_p,
request.min_p,
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n,
assistant.content)
) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
promptTokens += m_promptResponseTokens; // previous responses are part of current prompt
}
QString lastMessage = request.messages.last().content;
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(lastMessage); // blocks
resetResponse();
for (int i = 0; i < request.n; ++i) {
if (!promptInternal(
m_collections,
lastMessage,
promptTemplate,
request.max_tokens,
top_k,
request.top_p,
request.min_p,
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n)
) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
ChatPromptResult result;
try {
result = promptInternalChat(m_collections, promptCtx);
} catch (const std::exception &e) {
emit responseChanged(e.what());
emit responseStopped(0);
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
responses.append({response(), m_databaseResults});
// FIXME(jared): these are UI counts and do not include framing tokens, which they should
responses.emplace_back(result.response, result.databaseResults);
if (i == 0)
promptTokens += m_promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens;
if (i != request.n - 1)
resetResponse();
promptTokens = result.promptTokens;
responseTokens += result.responseTokens;
}
QJsonObject responseObject {
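
For illustration only (not from this commit): the kind of OpenAI-style payload handleChatRequest() now accepts, including a leading "system" message that the parser above maps to ChatItem::system_tag instead of dropping. The function name exampleChatRequestBody and the exact field set are assumptions; the role strings and model name come from the parser and the test file in this commit.

#include <QByteArray>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>

QByteArray exampleChatRequestBody()
{
    QJsonObject body {
        { "model",      "Llama 3.2 1B Instruct" },
        { "max_tokens", 64 },
        { "messages",   QJsonArray {
            QJsonObject { { "role", "system" }, { "content", "You are a terse assistant." } },
            QJsonObject { { "role", "user"   }, { "content", "Name a prime number."       } },
        } },
    };
    return QJsonDocument(body).toJson(QJsonDocument::Compact);
}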

View File

@ -33,7 +33,7 @@ public Q_SLOTS:
void start();
Q_SIGNALS:
void requestServerNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
void requestResetResponseState();
private:
auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;

View File

@ -3,23 +3,41 @@
#include <fmt/base.h>
#include <fmt/format.h>
#include <QByteArray>
#include <QJsonValue>
#include <QLatin1StringView>
#include <QString>
#include <QStringView>
#include <QUtf8StringView>
#include <QVariant>
#include <string>
#include <initializer_list>
#include <string_view>
#include <utility>
class QJsonObject;
// fmtlib formatters for QString and QVariant
#define MAKE_FORMATTER(type, conversion) \
template <> \
struct fmt::formatter<type, char>: fmt::formatter<std::string, char> { \
template <typename FmtContext> \
FmtContext::iterator format(const type &value, FmtContext &ctx) const \
{ \
return formatter<std::string, char>::format(conversion, ctx); \
} \
#define MAKE_FORMATTER(type, conversion) \
template <> \
struct fmt::formatter<type, char>: fmt::formatter<std::string_view, char> { \
template <typename FmtContext> \
FmtContext::iterator format(const type &value, FmtContext &ctx) const \
{ \
auto valueUtf8 = (conversion); \
std::string_view view(valueUtf8.cbegin(), valueUtf8.cend()); \
return formatter<std::string_view, char>::format(view, ctx); \
} \
}
MAKE_FORMATTER(QString, value.toStdString() );
MAKE_FORMATTER(QVariant, value.toString().toStdString());
MAKE_FORMATTER(QUtf8StringView, value );
MAKE_FORMATTER(QStringView, value.toUtf8() );
MAKE_FORMATTER(QString, value.toUtf8() );
MAKE_FORMATTER(QVariant, value.toString().toUtf8());
// alternative to QJsonObject's initializer_list constructor that accepts Latin-1 strings
QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args);
#include "utils.inl"
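
A usage sketch (not part of this commit): with utils.h included, Qt string types can be passed straight to fmtlib and are formatted via their UTF-8 bytes, as the revised MAKE_FORMATTER specializations above arrange. The function describeModel is a hypothetical example.

#include "utils.h"

#include <fmt/format.h>

#include <QString>
#include <QVariant>

#include <string>

std::string describeModel(const QString &name, const QVariant &layers)
{
    // Both arguments go through the fmt::formatter specializations defined above.
    return fmt::format("model {} with {} GPU layers", name, layers);
}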

View File

@ -0,0 +1,9 @@
#include <QJsonObject>
inline QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args)
{
QJsonObject obj;
for (auto &arg : args)
obj.insert(arg.first, arg.second);
return obj;
}
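
A usage sketch (not part of this commit), mirroring the choice objects built in server.cpp but with Latin-1 keys, which QJsonObject's own initializer_list constructor does not accept. The function exampleChoice is hypothetical; the "_L1" literal comes from Qt::Literals::StringLiterals.

#include "utils.h"

#include <QJsonObject>
#include <QJsonValue>
#include <QString>

using namespace Qt::Literals::StringLiterals;

QJsonObject exampleChoice(const QString &text, int index)
{
    // Keys stay Latin-1 views; values convert to QJsonValue as usual.
    return makeJsonObject({
        { "text"_L1,          text             },
        { "index"_L1,         index            },
        { "logprobs"_L1,      QJsonValue::Null },
        { "finish_reason"_L1, "stop"_L1        },
    });
}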

View File

@ -203,11 +203,10 @@ EXPECTED_MODEL_INFO = {
EXPECTED_COMPLETIONS_RESPONSE = {
'choices': [
{
'finish_reason': 'stop',
'finish_reason': 'length',
'index': 0,
'logprobs': None,
'references': None,
'text': ' jumps over the lazy dog.',
'text': ' jumps over the lazy dog.\n',
},
],
'id': 'placeholder',
@ -242,18 +241,14 @@ def test_with_models(chat_server_with_model: None) -> None:
'type': 'invalid_request_error',
}}
data = {
'model': 'Llama 3.2 1B Instruct',
'prompt': 'The quick brown fox',
'temperature': 0,
}
data = dict(
model = 'Llama 3.2 1B Instruct',
prompt = 'The quick brown fox',
temperature = 0,
max_tokens = 6,
)
response = request.post('completions', data=data)
assert len(response['choices']) == 1
assert response['choices'][0].keys() == {'text', 'index', 'logprobs', 'references', 'finish_reason'}
assert response['choices'][0]['text'] == ' jumps over the lazy dog.'
assert 'created' in response
response.pop('created') # Remove the dynamic field for comparison
del response['created'] # Remove the dynamic field for comparison
assert response == EXPECTED_COMPLETIONS_RESPONSE