Remove binary state from high-level API and use Jinja templates (#3147)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
Jared Van Bortel 2024-11-25 10:04:17 -05:00 committed by GitHub
parent 3320094d29
commit 225bf6be93
54 changed files with 3423 additions and 2224 deletions

.gitmodules
View File

@ -17,3 +17,6 @@
[submodule "gpt4all-chat/deps/QXlsx"] [submodule "gpt4all-chat/deps/QXlsx"]
path = gpt4all-chat/deps/QXlsx path = gpt4all-chat/deps/QXlsx
url = https://github.com/nomic-ai/QXlsx.git url = https://github.com/nomic-ai/QXlsx.git
[submodule "gpt4all-chat/deps/Jinja2Cpp"]
path = gpt4all-chat/deps/Jinja2Cpp
url = https://github.com/nomic-ai/jinja2cpp.git

View File

@ -5,6 +5,7 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <expected>
#include <functional> #include <functional>
#include <optional> #include <optional>
#include <span> #include <span>
@ -24,6 +25,10 @@ using namespace std::string_literals;
class LLModel { class LLModel {
public: public:
using Token = int32_t; using Token = int32_t;
using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
using ProgressCallback = std::function<bool(float progress)>;
class BadArchError: public std::runtime_error { class BadArchError: public std::runtime_error {
public: public:
@ -101,6 +106,7 @@ public:
static int32_t maxContextLength(const std::string &modelPath); static int32_t maxContextLength(const std::string &modelPath);
static int32_t layerCount(const std::string &modelPath); static int32_t layerCount(const std::string &modelPath);
static bool isEmbeddingModel(const std::string &modelPath); static bool isEmbeddingModel(const std::string &modelPath);
static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
static void setImplementationsSearchPath(const std::string &path); static void setImplementationsSearchPath(const std::string &path);
static const std::string &implementationsSearchPath(); static const std::string &implementationsSearchPath();
static bool hasSupportedCPU(); static bool hasSupportedCPU();
@ -124,7 +130,6 @@ public:
}; };
struct PromptContext { struct PromptContext {
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_predict = 200; int32_t n_predict = 200;
int32_t top_k = 40; int32_t top_k = 40;
float top_p = 0.9f; float top_p = 0.9f;
@ -136,8 +141,6 @@ public:
float contextErase = 0.5f; // percent of context to erase if we exceed the context window float contextErase = 0.5f; // percent of context to erase if we exceed the context window
}; };
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {} explicit LLModel() {}
virtual ~LLModel() {} virtual ~LLModel() {}
@ -154,16 +157,12 @@ public:
// This method requires the model to return true from supportsCompletion otherwise it will throw // This method requires the model to return true from supportsCompletion otherwise it will throw
// an error // an error
virtual void prompt(const std::string &prompt, virtual void prompt(std::string_view prompt,
const std::string &promptTemplate, const PromptCallback &promptCallback,
std::function<bool(int32_t)> promptCallback, const ResponseCallback &responseCallback,
std::function<bool(int32_t, const std::string&)> responseCallback, const PromptContext &ctx);
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::optional<std::string_view> fakeReply = {});
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend); virtual int32_t countPromptTokens(std::string_view prompt) const;
virtual size_t embeddingSize() const { virtual size_t embeddingSize() const {
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings"); throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@ -209,23 +208,22 @@ public:
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; } void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
virtual int32_t contextLength() const = 0; virtual int32_t contextLength() const = 0;
virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;
protected: protected:
// These are pure virtual because subclasses need to implement as the default implementation of // These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions // 'prompt' above calls these functions
virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0; virtual std::vector<Token> tokenize(std::string_view str) const = 0;
virtual bool isSpecialToken(Token id) const = 0; virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0; virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0; virtual void initSampler(const PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0; virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0; virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0; virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
virtual int32_t inputLength() const = 0; virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0; virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input) virtual void setModelInputPosition(int32_t pos) = 0;
-> std::vector<Token>::const_iterator = 0; virtual void appendInputToken(Token tok) = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0; virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0; virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0; virtual bool shouldAddBOS() const = 0;
@ -242,6 +240,12 @@ protected:
return -1; return -1;
} }
virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
(void)modelPath;
return std::unexpected("not implemented");
}
const Implementation *m_implementation = nullptr; const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback; ProgressCallback m_progressCallback;
@ -253,19 +257,15 @@ protected:
return true; return true;
} }
bool decodePrompt(std::function<bool(int32_t)> promptCallback, // prefill context with prompt
std::function<bool(int32_t, const std::string&)> responseCallback, auto decodePrompt(const PromptCallback &promptCallback,
bool allowContextShift, const PromptContext &promptCtx,
PromptContext &promptCtx, std::vector<Token> embd_inp)
std::vector<Token> embd_inp, -> std::optional<int32_t>;
bool isResponse = false, // generate a response
bool alwaysDecode = false); void generateResponse(const ResponseCallback &responseCallback,
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback, const PromptContext &promptCtx,
bool allowContextShift, int32_t nPast);
PromptContext &promptCtx);
protected:
Token m_tokenize_last_token = -1; // not serialized
friend class LLMImplementation; friend class LLMImplementation;
}; };

View File

@ -35,16 +35,15 @@ typedef int32_t token_t;
* behavior. * behavior.
*/ */
struct llmodel_prompt_context { struct llmodel_prompt_context {
int32_t n_past; // number of tokens in past conversation
int32_t n_predict; // number of tokens to predict int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution float temp; // temperature to adjust model's output distribution
int32_t n_batch; // number of predictions to generate in parallel int32_t n_batch; // number of predictions to generate in parallel
float repeat_penalty; // penalty factor for repeated tokens float repeat_penalty; // penalty factor for repeated tokens
int32_t repeat_last_n; // last n tokens to penalize int32_t repeat_last_n; // last n tokens to penalize
float context_erase; // percent of context to erase if we exceed the context window float context_erase; // percent of context to erase if we exceed the context window
}; };
struct llmodel_gpu_device { struct llmodel_gpu_device {
@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;
/** /**
* Callback type for prompt processing. * Callback type for prompt processing.
* @param token_id The token id of the prompt. * @param token_ids An array of token ids of the prompt.
* @param n_token_ids The number of tokens in the array.
* @param cached Whether the tokens were already in cache.
* @return a bool indicating whether the model should keep processing. * @return a bool indicating whether the model should keep processing.
*/ */
typedef bool (*llmodel_prompt_callback)(int32_t token_id); typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);
/** /**
* Callback type for response. * Callback type for response.
@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string. * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating. * @return a bool indicating whether the model should keep generating.
*/ */
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);
/** /**
* Embedding cancellation callback for use with llmodel_embed. * Embedding cancellation callback for use with llmodel_embed.
@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
*/ */
typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend); typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
typedef void (*llmodel_special_token_callback)(const char *name, const char *token);
/** /**
* Create a llmodel instance. * Create a llmodel instance.
* Recognises correct model type from file at model_path * Recognises correct model type from file at model_path
@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
* Generate a response using the model. * Generate a response using the model.
* @param model A pointer to the llmodel_model instance. * @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt. * @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt. * @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response. * @param response_callback A callback function for handling the generated response.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure. * @param ctx A pointer to the llmodel_prompt_context structure.
* @param error A pointer to a string; will only be set on error.
*/ */
void llmodel_prompt(llmodel_model model, const char *prompt, bool llmodel_prompt(llmodel_model model,
const char *prompt_template, const char *prompt,
llmodel_prompt_callback prompt_callback, llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback, llmodel_response_callback response_callback,
bool allow_context_shift, llmodel_prompt_context *ctx,
llmodel_prompt_context *ctx, const char **error);
bool special,
const char *fake_reply);
/** /**
* Generate an embedding using the model. * Generate an embedding using the model.
@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
*/ */
const char *llmodel_model_gpu_device_name(llmodel_model model); const char *llmodel_model_gpu_device_name(llmodel_model model);
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
if (keyidx != -1) { if (keyidx != -1) {
value = gguf_get_val_u32(ctx, keyidx); value = gguf_get_val_u32(ctx, keyidx);
} else { } else {
std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n"; std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
} }
} }
@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const
return bytesRead; return bytesRead;
} }
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special) std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str) const
{ {
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4); std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all( int32_t fres_len = llama_tokenize(
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true
/*parse_special*/ special, /*insert_space*/ insertSpace
); );
fres.resize(fres_len); fres.resize(fres_len);
if (fres_len)
m_tokenize_last_token = fres.back();
return fres; return fres;
} }
@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const
return std::string(result.data(), result.size()); return std::string(result.data(), result.size());
} }
void LLamaModel::initSampler(PromptContext &promptCtx) void LLamaModel::initSampler(const PromptContext &promptCtx)
{ {
auto *model = d_ptr->model; auto *model = d_ptr->model;
auto *chain = d_ptr->sampler_chain; auto *chain = d_ptr->sampler_chain;
@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1); return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
} }
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const bool LLamaModel::evalTokens(int32_t nPast, std::span<const Token> tokens) const
{ {
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1); assert(!tokens.empty());
llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1);
llama_batch batch = llama_batch_init(tokens.size(), 0, 1); llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
for (int32_t i = 0; i < batch.n_tokens; i++) { for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token [i] = tokens[i]; batch.token [i] = tokens[i];
batch.pos [i] = ctx.n_past + i; batch.pos [i] = nPast + i;
batch.n_seq_id[i] = 1; batch.n_seq_id[i] = 1;
batch.seq_id [i][0] = 0; batch.seq_id [i][0] = 0;
batch.logits [i] = false; batch.logits [i] = false;
@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
return res == 0; return res == 0;
} }
void LLamaModel::shiftContext(PromptContext &promptCtx) void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast)
{ {
// infinite text generation via context shifting // infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens // erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS(); int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past; int n_past = *nPast;
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase)); int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));
assert(n_discard > 0); assert(n_discard > 0);
@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
auto &inp = d_ptr->inputTokens; auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard); inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size(); *nPast = inp.size();
} }
int32_t LLamaModel::contextLength() const int32_t LLamaModel::contextLength() const
@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx); return llama_n_ctx(d_ptr->ctx);
} }
auto LLamaModel::specialTokens() -> std::unordered_map<std::string, std::string> const
{
if (!d_ptr->model)
throw std::logic_error("model not loaded");
std::unordered_map<std::string, std::string> tokens;
if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("bos_token", tokenToString(id));
if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL)
tokens.emplace("eos_token", tokenToString(id));
return tokens;
}
int32_t LLamaModel::inputLength() const int32_t LLamaModel::inputLength() const
{ {
return d_ptr->inputTokens.size(); return d_ptr->inputTokens.size();
} }
void LLamaModel::setTokenizeInputPosition(int32_t pos) int32_t LLamaModel::computeModelInputPosition(std::span<const Token> input) const
{ {
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}
auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}
// find common prefix // find common prefix
auto cacheIt = d_ptr->inputTokens.begin(); auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin(); auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) { while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos; ++cacheIt; ++inputIt;
} }
// tell the caller to ignore the tokens between [begin, inputIt) // tell the caller to ignore the tokens between [begin, inputIt)
return inputIt; return inputIt - input.begin();
} }
void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos) void LLamaModel::setModelInputPosition(int32_t pos)
{ {
auto &inp = d_ptr->inputTokens; auto &inp = d_ptr->inputTokens;
assert(pos >= 0); assert(pos >= 0);
@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
// truncate token cache to end at the new n_past // truncate token cache to end at the new n_past
if (pos < inp.size()) if (pos < inp.size())
inp.resize(pos); inp.resize(pos);
ctx.n_past = pos;
} }
void LLamaModel::appendInputToken(PromptContext &ctx, Token tok) void LLamaModel::appendInputToken(Token tok)
{ {
d_ptr->inputTokens.push_back(tok); d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
} }
auto LLamaModel::inputTokens() const -> std::span<const Token> auto LLamaModel::inputTokens() const -> std::span<const Token>
@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const
return get_arch_key_u32(modelPath, "block_count"); return get_arch_key_u32(modelPath, "block_count");
} }
// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded
// models into a class that keeps the model file open
auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
{
auto *ctx = load_gguf(modelPath);
if (!ctx)
return std::unexpected("failed to open model file");
std::expected<std::string, std::string> result;
enum gguf_type ktype;
const int kid = gguf_find_key(ctx, "tokenizer.chat_template");
if (kid == -1) {
result = std::unexpected("key not found");
goto cleanup;
}
ktype = gguf_get_kv_type(ctx, kid);
if (ktype != GGUF_TYPE_STRING) {
result = std::unexpected(
"expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype)
);
goto cleanup;
}
result = gguf_get_val_str(ctx, kid);
cleanup:
gguf_free(ctx);
return result;
}
#ifdef GGML_USE_VULKAN #ifdef GGML_USE_VULKAN
static const char *getVulkanVendorName(uint32_t vendorID) static const char *getVulkanVendorName(uint32_t vendorID)
{ {

View File

@ -11,6 +11,7 @@
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector> #include <vector>
#include <unordered_map>
struct LLamaPrivate; struct LLamaPrivate;
struct EmbModelSpec; struct EmbModelSpec;
@ -49,26 +50,26 @@ public:
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override; size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
int32_t contextLength() const override; int32_t contextLength() const override;
auto specialTokens() -> std::unordered_map<std::string, std::string> const override;
protected: protected:
std::vector<Token> tokenize(std::string_view str, bool special) override; std::vector<Token> tokenize(std::string_view str) const override;
bool isSpecialToken(Token id) const override; bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override; std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override; void initSampler(const PromptContext &ctx) override;
Token sampleToken() const override; Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override; bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override; void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override;
int32_t inputLength() const override; int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override; int32_t computeModelInputPosition(std::span<const Token> input) const override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input) void setModelInputPosition(int32_t pos) override;
-> std::vector<Token>::const_iterator override; void appendInputToken(Token tok) override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
std::span<const Token> inputTokens() const override; std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override; const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override; bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override; int32_t maxContextLength(std::string const &modelPath) const override;
int32_t layerCount(std::string const &modelPath) const override; int32_t layerCount(std::string const &modelPath) const override;
auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string> override;
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality, void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb, size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,

View File

@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath)
return llama && llama->isEmbeddingModel(modelPath); return llama && llama->isEmbeddingModel(modelPath);
} }
auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>
{
auto *llama = constructGlobalLlama();
return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available");
}
void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) void LLModel::Implementation::setImplementationsSearchPath(const std::string& path)
{ {
s_implementations_search_path = path; s_implementations_search_path = path;

View File

@ -7,7 +7,6 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <exception> #include <exception>
#include <functional>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token));
struct LLModelWrapper { struct LLModelWrapper {
LLModel *llModel = nullptr; LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; } ~LLModelWrapper() { delete llModel; }
}; };
@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)}); return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
} }
void llmodel_prompt(llmodel_model model, const char *prompt, bool llmodel_prompt(llmodel_model model,
const char *prompt_template, const char *prompt,
llmodel_prompt_callback prompt_callback, llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback, llmodel_response_callback response_callback,
bool allow_context_shift, llmodel_prompt_context *ctx,
llmodel_prompt_context *ctx, const char **error)
bool special,
const char *fake_reply)
{ {
auto *wrapper = static_cast<LLModelWrapper *>(model); auto *wrapper = static_cast<LLModelWrapper *>(model);
auto response_func = [response_callback](int32_t token_id, const std::string &response) { // Copy the C prompt context
return response_callback(token_id, response.c_str()); LLModel::PromptContext promptContext {
.n_predict = ctx->n_predict,
.top_k = ctx->top_k,
.top_p = ctx->top_p,
.min_p = ctx->min_p,
.temp = ctx->temp,
.n_batch = ctx->n_batch,
.repeat_penalty = ctx->repeat_penalty,
.repeat_last_n = ctx->repeat_last_n,
.contextErase = ctx->context_erase,
}; };
// Copy the C prompt context auto prompt_func = [prompt_callback](std::span<const LLModel::Token> token_ids, bool cached) {
wrapper->promptContext.n_past = ctx->n_past; return prompt_callback(token_ids.data(), token_ids.size(), cached);
wrapper->promptContext.n_predict = ctx->n_predict; };
wrapper->promptContext.top_k = ctx->top_k; auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) {
wrapper->promptContext.top_p = ctx->top_p; return response_callback(token_id, piece.data());
wrapper->promptContext.min_p = ctx->min_p; };
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method // Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift, try {
wrapper->promptContext, special, wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext);
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt); } catch (std::exception const &e) {
llmodel_set_error(error, e.what());
return false;
}
// Update the rest of the C prompt context return true;
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->min_p = wrapper->promptContext.min_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
} }
float *llmodel_embed( float *llmodel_embed(
@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model)
const auto *wrapper = static_cast<LLModelWrapper *>(model); const auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->gpuDeviceName(); return wrapper->llModel->gpuDeviceName();
} }
int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
try {
return wrapper->llModel->countPromptTokens(prompt);
} catch (const std::exception& e) {
llmodel_set_error(error, e.what());
return -1;
}
}
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
{
auto *wrapper = static_cast<const LLModelWrapper *>(model);
for (auto &[name, token] : wrapper->llModel->specialTokens())
callback(name.c_str(), token.c_str());
}

View File

@ -4,232 +4,120 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional>
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
#include <optional> #include <optional>
#include <regex> #include <ranges>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <vector> #include <vector>
namespace ranges = std::ranges; namespace ranges = std::ranges;
namespace views = std::ranges::views;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) void LLModel::prompt(
{ std::string_view prompt,
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))"); const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
if (!isModelLoaded())
throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!supportsCompletion())
throw std::invalid_argument("Not a text completion model.");
if (!promptCtx.n_batch)
throw std::invalid_argument("Batch size cannot be zero.");
if (!promptCtx.n_predict)
return; // nothing requested
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex); auto embd_inp = tokenize(prompt);
placeholders.clear(); if (embd_inp.empty())
placeholders.insert(placeholders.end(), it, std::sregex_iterator()); throw std::invalid_argument("Prompt tokenized to zero tokens.");
if (placeholders.size() > 2) { if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp)))
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size()); generateResponse(responseCallback, promptCtx, /*n_past*/ *res);
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
} }
void LLModel::prompt(const std::string &prompt, int32_t LLModel::countPromptTokens(std::string_view prompt) const
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::optional<std::string_view> fakeReply)
{ {
if (!isModelLoaded()) { if (!isModelLoaded())
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n"; throw std::invalid_argument("Attempted to tokenize with an unloaded model.");
return; return int32_t(tokenize(prompt).size());
}
if (!supportsCompletion()) {
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
responseCallback(-1, errorMessage);
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
return;
}
// sanity checks
if (promptCtx.n_past > contextLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
setTokenizeInputPosition(promptCtx.n_past);
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty())
embd_inp = tokenize(userPrefix, true);
// user input (shouldn't have special token processing)
auto tokens = tokenize(prompt, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(userToAsst, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
}
}
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error
// decode the assistant's reply, either generated or spoofed
if (!fakeReply) {
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(*fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
return; // error
}
// decode the rest of the prompt template
// template: end of assistant prompt
std::string asstSuffix;
if (placeholders.size() >= 2) {
size_t start = placeholders[1].position() + placeholders[1].length();
asstSuffix = promptTemplate.substr(start);
} else {
asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(asstSuffix, true);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
}
} }
// returns false on error auto LLModel::decodePrompt(
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback, const PromptCallback &promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback, const PromptContext &promptCtx,
bool allowContextShift, std::vector<Token> embd_inp
PromptContext &promptCtx, ) -> std::optional<int32_t>
std::vector<Token> embd_inp, {
bool isResponse, assert(!embd_inp.empty());
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << contextLength() << "!\n";
return false;
}
// FIXME(jared): There are mitigations for this situation, such as making room before int32_t nCtx = contextLength();
// copying the prompt context, or restoring the KV cache when we restore the prompt int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << contextLength() << "\n";
return false;
}
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}
// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past. // requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result. // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp); int32_t nPast = computeModelInputPosition(embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();
// always decode up to a full batch before generating, even if cached // always decode up to a full batch before generating, even if cached
if (alwaysDecode) nPast -= std::min(n_batch, nPast);
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
setModelInputPosition(promptCtx, promptCtx.n_past + start_offset); // TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache
if (!nPast && int32_t(embd_inp.size()) > nCtx) {
// no cache hit -> shift the input before even processing
// execute the callback even for skipped tokens int32_t nKeep = shouldAddBOS();
size_t i = 0; auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase));
for (; i < start_offset; i++) { int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength));
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok); // execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care
if (!res) auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard);
return false; if (!promptCallback(discardedTokens, true))
return std::nullopt;
// erase nDiscard tokens
embd_inp.erase(discardedTokens.begin(), discardedTokens.end());
assert(int32_t(embd_inp.size()) <= nCtx);
// check the cache again, just in case
nPast = computeModelInputPosition(embd_inp);
nPast -= std::min(n_batch, nPast);
} }
setModelInputPosition(nPast);
// execute the callback even for skipped tokens
if (!promptCallback(embd_inp | views::take(nPast), true))
return std::nullopt;
// process the prompt in batches // process the prompt in batches
while (i < embd_inp.size()) { for (int32_t i = nPast; i < embd_inp.size();) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size()); auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size()));
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end); std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out... // Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) { if (nPast + int32_t(batch.size()) > nCtx) {
assert(allowContextShift); shiftContext(promptCtx, &nPast);
shiftContext(promptCtx); assert(nPast + int32_t(batch.size()) <= nCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
} }
if (!evalTokens(promptCtx, batch)) { // FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation
std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n"; if (!evalTokens(nPast, batch))
return false; throw std::runtime_error("An internal error was encountered during prompt processing.");
}
size_t tokens = batch_end - i; for (auto &tok : batch) {
for (size_t t = 0; t < tokens; ++t) { appendInputToken(tok);
Token tok = batch[t]; nPast++;
appendInputToken(promptCtx, tok); if (!promptCallback({ &tok, 1 }, false))
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok); return std::nullopt;
if (!res)
return false;
} }
i = batch_end; i = batch_end;
} }
return true; return nPast;
} }
/* /*
@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st
return std::string::npos; return std::string::npos;
} }
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback, void LLModel::generateResponse(
bool allowContextShift, const ResponseCallback &responseCallback,
PromptContext &promptCtx) { const PromptContext &promptCtx,
int32_t nPast
) {
static const char *stopSequences[] { static const char *stopSequences[] {
"### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context", "### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context",
"<|im_start|>", "<|im_end|>", "<|endoftext|>",
}; };
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
initSampler(promptCtx); initSampler(promptCtx);
std::string cachedResponse; std::string cachedResponse;
@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedTokens.push_back(new_tok.value()); cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece; cachedResponse += new_piece;
auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool { auto accept = [this, &promptCtx, &new_tok, &nPast] {
// Shift context if out of space // Shift context if out of space
if (promptCtx.n_past >= contextLength()) { if (nPast >= contextLength()) {
(void)allowContextShift; shiftContext(promptCtx, &nPast);
assert(allowContextShift); assert(nPast < contextLength());
shiftContext(promptCtx);
assert(promptCtx.n_past < contextLength());
} }
// Accept the token // Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value(); Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { &tok, 1 })) { if (!evalTokens(nPast, { &tok, 1 }))
// TODO(jared): raise an exception throw std::runtime_error("An internal error was encountered during response generation.");
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
appendInputToken(promptCtx, tok); appendInputToken(tok);
return true; nPast++;
}; };
// Check for EOS // Check for EOS
@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
lengthLimit = cachedResponse.size() - new_piece.size(); lengthLimit = cachedResponse.size() - new_piece.size();
} }
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< contextLength() << "\n";
stop = true;
}
// Empty the cache, up to the length limit // Empty the cache, up to the length limit
std::string::size_type responseLength = 0; std::string::size_type responseLength = 0;
while (!cachedTokens.empty()) { while (!cachedTokens.empty()) {
@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size()); cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());
// Accept the token, if needed (not cached) // Accept the token, if needed (not cached)
if (cachedTokens.empty() && new_tok && !accept()) if (cachedTokens.empty() && new_tok)
return; accept();
// Send the token // Send the token
if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) { if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
assert(!cachedTokens.empty() && cachedTokens.back() == new_tok); assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
if (stop) { if (stop) {
cachedTokens.pop_back(); cachedTokens.pop_back();
} else if (!accept()) { } else {
return; accept();
} }
} }
} }
@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
auto discard_start = inp.end() - cachedTokens.size(); auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin())); assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif #endif
promptCtx.n_past -= cachedTokens.size();
} }
void LLModel::embed( void LLModel::embed(

View File

@ -0,0 +1,206 @@
## What are chat templates?
Natively, large language models only know how to complete plain text and do not know the difference between their input and their output. In order to support a chat with a person, LLMs are designed to use a template to convert the conversation to plain text using a specific format.
For a given model, it is important to use an appropriate chat template, as each model is designed to work best with a specific format. The chat templates included with the built-in models should be sufficient for most purposes.
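For example, with a ChatML-style format (one common convention; other models use different control tokens and layout), a conversation consisting of a system message and a single user message might be rendered to plain text like this, with the model expected to write its reply after the final line:

```
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
```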
There are two reasons you would want to alter the chat template:
- You are sideloading a model and there is no chat template available,
- You would like to have greater control over the input to the LLM than a system message provides.
## What is a system message?
A system message is a message that controls the responses from the LLM in a way that affects the entire conversation. System messages can be short, such as "Speak like a pirate.", or they can be long and contain a lot of context for the LLM to keep in mind.
Not all models are designed to use a system message, so system messages work better with some models than with others.
## How do I customize the chat template or system message?
To customize the chat template or system message, go to Settings > Model. Make sure to select the correct model at the top. If you clone a model, you can use a different chat template or system message from the base model, enabling you to use different settings for each conversation.
These settings take effect immediately. After changing them, you can click "Redo last response" in the chat view, and the response will take the new settings into account.
## Do I need to write a chat template?
You typically do not need to write your own chat template. The exception is models that are not in the official model list and do not come with a chat template built-in. These will show a "Clear" option above the chat template field in the Model Settings page instead of a "Reset" option. See the section on [finding] or [creating] a chat template.
[finding]: #how-do-i-find-a-chat-template
[creating]: #advanced-how-do-chat-templates-work
## What changed in GPT4All v3.5?
GPT4All v3.5 overhauled the chat template system. There are three crucial differences:
- The chat template now formats an entire conversation instead of a single pair of messages,
- The chat template now uses Jinja syntax instead of `%1` and `%2` placeholders,
- And the system message should no longer contain control tokens or trailing whitespace.
If you are using any chat templates or system messages that were added or altered from the default before upgrading to GPT4All v3.5 or newer, they will no longer work. See below for how to solve common errors you may see after upgrading.
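To illustrate the second difference, a legacy ChatML-style template using the old placeholders looked roughly like this, where `%1` stood for the user message and `%2` for the response:

```
<|im_start|>user
%1<|im_end|>
<|im_start|>assistant
%2<|im_end|>
```

An equivalent Jinja template loops over the whole conversation instead. This is only a sketch for comparison, not the exact template shipped with any particular model:

```jinja
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
```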
## Error/Warning: System message is not plain text.
This is easy to fix. Go to the model's settings and look at the system message. There are three things to look for:
- Control tokens such as `<|im_start|>`, `<|start_header_id|>`, or `<|system|>`
- A prefix such as `### System` or `SYSTEM:`
- Trailing whitespace, such as a space character or blank line.
If you see any of these things, remove them. For example, this legacy system prompt:
```
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
```
Should become this:
```
You are a helpful assistant.
```
If you do not see anything that needs to be changed, you can dismiss the error by making a minor modification to the message and then changing it back.
If you see a warning, your system message does not appear to be plain text. If you believe this warning is incorrect, it can be safely ignored. If in doubt, ask on the [Discord].
[Discord]: https://discord.gg/mGZE39AS3e
## Error: Legacy system prompt needs to be updated in Settings.
This is the same as [above][above-1], but appears on the chat page.
[above-1]: #errorwarning-system-message-is-not-plain-text
## Error/Warning: Chat template is not in Jinja format.
This is the result of attempting to use an old-style template (possibly from a previous version) in GPT4All 3.5+.
Go to the Model Settings page and select the affected model. If you see a "Reset" button, and you have not intentionally modified the prompt template, you can click "Reset". Otherwise, this is what you can do:
1. Back up your chat template by copying it safely to a text file and saving it. In the next step, it will be removed from GPT4All.
2. Click "Reset" or "Clear".
3. If you clicked "Clear", the chat template is now gone. Follow the steps to [find][finding] or [create][creating] a basic chat template for your model.
4. Customize the chat template to suit your needs. For help, read the section about [creating] a chat template.
## Error: Legacy prompt template needs to be updated in Settings.
This is the same as [above][above-2], but appears on the chat page.
[above-2]: #errorwarning-chat-template-is-not-in-jinja-format
## The chat template has a syntax error.
If there is a syntax error while editing the chat template, the details will be displayed in an error message above the input box. This could be because the chat template is not actually in Jinja format (see [above][above-2]).
Otherwise, you have either typed something incorrectly, or the model comes with a template that is incompatible with GPT4All. See [the below section][creating] on creating chat templates and make sure that everything is correct. When in doubt, ask on the [Discord].
## Error: No chat template configured.
This may appear for models that are not from the official model list and do not include a chat template. Older versions of GPT4All picked a poor default in this case. You will get much better results if you follow the steps to [find][finding] or [create][creating] a chat template for your model.
## Error: The chat template cannot be blank.
If the button above the chat template on the Model Settings page says "Clear", see [above][above-3]. If you see "Reset", click that button to restore a reasonable default. Also see the section on [syntax errors][chat-syntax-error].
[above-3]: #error-no-chat-template-configured
[chat-syntax-error]: #the-chat-template-has-a-syntax-error
## How do I find a chat template?
When in doubt, you can always ask the [Discord] community for help. Below are the instructions to find one on your own.
The authoritative source for a model's chat template is the HuggingFace repo that the original (non-GGUF) model came from. First, you should find this page. If you just have a model file, you can try a Google search for the model's name. If you know the page you downloaded the GGUF model from, its README usually links to the original non-GGUF model.
Once you have located the original model, there are two methods you can use to extract its chat template. Pick whichever one you are most comfortable with.
### Using the CLI (all models)
1. Install `jq` using your preferred package manager - e.g. Chocolatey (Windows), Homebrew (macOS), or apt (Ubuntu).
2. Download `tokenizer_config.json` from the model's "Files and versions" tab.
3. Open a command prompt in the directory to which you downloaded `tokenizer_config.json`.
4. Run `jq -r ".chat_template" tokenizer_config.json`. This shows the chat template in a human-readable form. You can copy this and paste it into the settings page.
5. (Optional) You can save the output to a text file like this: `jq -r ".chat_template" tokenizer_config.json >chat_template.txt`
If the output is "null", the model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (open models)
1. Install `transformers` using your preferred Python package manager, e.g. `pip install transformers`. Make sure it is at least version 4.43.0.
2. Copy the ID of the HuggingFace model, using the clipboard icon next to the name. For example, if the URL is `https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B`, the ID is `NousResearch/Hermes-2-Pro-Llama-3-8B`.
3. Open a python interpreter (`python`) and run the following commands. Change the model ID in the example to the one you copied.
```
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B')
>>> print(tokenizer.get_chat_template())
```
You can copy the output and paste it into the settings page.
4. (Optional) You can save the output to a text file like this:
```
>>> open('chat_template.txt', 'w').write(tokenizer.get_chat_template())
```
If you get a ValueError exception, this model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (gated models)
Some models, such as Llama and Mistral, do not allow public access to their chat template. You must either use the CLI method above or follow these instructions to use Python:
1. For these steps, you must have git and git-lfs installed.
2. You must have a HuggingFace account and be logged in.
3. You must already have access to the gated model. Otherwise, request access.
4. You must have an SSH key configured for git access to HuggingFace.
5. `git clone` the model's HuggingFace repo using the SSH clone URL. There is no need to download the entire model, which is very large. A good way to do this on Linux is:
```console
$ GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:meta-llama/Llama-3.1-8B-Instruct.git
$ cd Llama-3.1-8B-Instruct
$ git lfs pull -I "tokenizer.*"
```
6. Follow the above instructions for open models, but replace the model ID with the path to the directory containing `tokenizer_config.json`:
```
>>> tokenizer = AutoTokenizer.from_pretrained('.')
```
## Advanced: How do chat templates work?
The chat template is applied to the entire conversation you see in the chat window. The template loops over the list of messages, each containing `role` and `content` fields. `role` is either `user`, `assistant`, or `system`.
GPT4All also supports the special variables `bos_token`, `eos_token`, and `add_generation_prompt`. See the [HuggingFace docs] for what those do.
[HuggingFace docs]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#special-variables
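As a minimal sketch of how these variables fit together (not the template of any particular model), a Jinja chat template that uses `bos_token` and `add_generation_prompt` could look like this:

```jinja
{{- bos_token }}
{%- for message in messages %}
{{- message['role'] + ': ' + message['content'] + '\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- 'assistant: ' }}
{%- endif %}
```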
## Advanced: How do I make a chat template?
The best way to create a chat template is to start from an existing one as a reference, then modify it to use the format documented for the given model. The model's README page may explicitly give an example of its template, or it may mention the name of a well-known standard template such as ChatML, Alpaca, or Vicuna. GPT4All does not yet include presets for these templates, so they will have to be found in other models or taken from the community; an illustrative sketch is shown below.
For more information, see the very helpful [HuggingFace guide]. Some of this is not applicable, such as the information about tool calling and RAG - GPT4All implements those features differently.
Some models use a prompt template that does not intuitively map to a multi-turn chat, because it is more intended for single instructions. The [FastChat] implementation of these templates is a useful reference for the correct way to extend them to multiple messages.
[HuggingFace guide]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#advanced-template-writing-tips
[FastChat]: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
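As a worked example of extending a single-instruction format to multiple messages, the sketch below turns an Alpaca-style `### Instruction:` / `### Response:` prompt into a multi-turn template and renders it. The exact headers are an assumption; check them against the format documented for your model.
```python
# A sketch, assuming the model uses Alpaca-style "### Instruction:"/"### Response:"
# headers; verify the exact format against the model's documentation.
from jinja2.sandbox import ImmutableSandboxedEnvironment

alpaca_chat_template = """\
{%- if messages[0]['role'] == 'system' %}
    {{- messages[0]['content'] + '\\n\\n' }}
{%- endif %}
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}
    {%- elif message['role'] == 'assistant' %}
        {{- '### Response:\\n' + message['content'] + '\\n\\n' }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '### Response:\\n' }}
{%- endif %}"""

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
print(env.from_string(alpaca_chat_template).render(
    messages=[
        {"role": "system",    "content": "Below is an instruction that describes a task."},
        {"role": "user",      "content": "Say hello."},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user",      "content": "Now say it in French."},
    ],
    add_generation_prompt=True,
))
```
Only the template string itself goes into GPT4All's settings page; the surrounding Python here is just a convenient way to check that the template renders as intended.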
## Advanced: What are GPT4All v1 templates?
GPT4All supports its own template syntax, which is nonstandard but provides complete control over the way LocalDocs sources and file attachments are inserted into the conversation. These templates begin with `{# gpt4all v1 #}` and look similar to the example below.
For standard templates, GPT4All combines the user message, LocalDocs sources, and file attachments into the `content` field. GPT4All v1 templates do not do this, so the template itself must reference the sources and attachments for those features to work correctly.
```jinja
{# gpt4all v1 #}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{%- if message['role'] == 'user' %}
{%- for source in message['sources'] %}
{%- if loop.first %}
{{- '### Context:\n' }}
{%- endif %}
{{- 'Collection: ' + source['collection'] + '\n' +
'Path: ' + source['path'] + '\n' +
'Excerpt: ' + source['text'] + '\n\n' }}
{%- endfor %}
{%- endif %}
{%- for attachment in message['prompt_attachments'] %}
{{- attachment['processed_content'] + '\n\n' }}
{%- endfor %}
{{- message['content'] | trim }}
{{- '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
```
@ -9,7 +9,7 @@ import textwrap
import threading import threading
from enum import Enum from enum import Enum
from queue import Queue from queue import Queue
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
if sys.version_info >= (3, 9): if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources import importlib.resources as importlib_resources
@ -23,7 +23,9 @@ else:
from typing import TypedDict from typing import TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import TypeAlias from typing_extensions import ParamSpec, TypeAlias
T = TypeVar("T")
P = ParamSpec("P")
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]') EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
@ -31,7 +33,7 @@ cuda_found: bool = False
# TODO(jared): use operator.call after we drop python 3.10 support # TODO(jared): use operator.call after we drop python 3.10 support
def _operator_call(obj, /, *args, **kwargs): def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
return obj(*args, **kwargs) return obj(*args, **kwargs)
@ -116,16 +118,15 @@ llmodel = load_llmodel_library()
class LLModelPromptContext(ctypes.Structure): class LLModelPromptContext(ctypes.Structure):
_fields_ = [ _fields_ = [
("n_past", ctypes.c_int32), ("n_predict", ctypes.c_int32),
("n_predict", ctypes.c_int32), ("top_k", ctypes.c_int32),
("top_k", ctypes.c_int32), ("top_p", ctypes.c_float),
("top_p", ctypes.c_float), ("min_p", ctypes.c_float),
("min_p", ctypes.c_float), ("temp", ctypes.c_float),
("temp", ctypes.c_float), ("n_batch", ctypes.c_int32),
("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float), ("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32), ("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float), ("context_erase", ctypes.c_float),
] ]
@ -157,23 +158,21 @@ llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p] llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32) PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p) ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p) EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)
llmodel.llmodel_prompt.argtypes = [ llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_char_p, ctypes.c_char_p,
ctypes.c_char_p,
PromptCallback, PromptCallback,
ResponseCallback, ResponseCallback,
ctypes.c_bool,
ctypes.POINTER(LLModelPromptContext), ctypes.POINTER(LLModelPromptContext),
ctypes.c_bool, ctypes.POINTER(ctypes.c_char_p),
ctypes.c_char_p,
] ]
llmodel.llmodel_prompt.restype = None llmodel.llmodel_prompt.restype = ctypes.c_bool
llmodel.llmodel_embed.argtypes = [ llmodel.llmodel_embed.argtypes = [
ctypes.c_void_p, ctypes.c_void_p,
@ -222,6 +221,12 @@ llmodel.llmodel_model_backend_name.restype = ctypes.c_char_p
llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p] llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p
llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char_p)]
llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
llmodel.llmodel_model_foreach_special_token.restype = None
ResponseCallbackType = Callable[[int, str], bool] ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool] RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]' EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@ -266,7 +271,6 @@ class LLModel:
self.model_path = model_path.encode() self.model_path = model_path.encode()
self.n_ctx = n_ctx self.n_ctx = n_ctx
self.ngl = ngl self.ngl = ngl
self.context: LLModelPromptContext | None = None
self.buffer = bytearray() self.buffer = bytearray()
self.buff_expecting_cont_bytes: int = 0 self.buff_expecting_cont_bytes: int = 0
@ -286,6 +290,10 @@ class LLModel:
raise RuntimeError(f"Unable to instantiate model: {errmsg}") raise RuntimeError(f"Unable to instantiate model: {errmsg}")
self.model: ctypes.c_void_p | None = model self.model: ctypes.c_void_p | None = model
self.special_tokens_map: dict[str, str] = {}
llmodel.llmodel_model_foreach_special_token(
self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
)
def __del__(self, llmodel=llmodel): def __del__(self, llmodel=llmodel):
if hasattr(self, 'model'): if hasattr(self, 'model'):
@ -312,6 +320,19 @@ class LLModel:
dev = llmodel.llmodel_model_gpu_device_name(self.model) dev = llmodel.llmodel_model_gpu_device_name(self.model)
return None if dev is None else dev.decode() return None if dev is None else dev.decode()
def count_prompt_tokens(self, prompt: str) -> int:
if self.model is None:
self._raise_closed()
err = ctypes.c_char_p()
n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err))
if n_tok < 0:
s = err.value
errmsg = 'null' if s is None else s.decode()
raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
return n_tok
llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
@staticmethod @staticmethod
def list_gpus(mem_required: int = 0) -> list[str]: def list_gpus(mem_required: int = 0) -> list[str]:
""" """
@ -375,48 +396,6 @@ class LLModel:
raise Exception("Model not loaded") raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model) return llmodel.llmodel_threadCount(self.model)
def _set_context(
self,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
):
if self.context is None:
context = LLModelPromptContext(
n_past=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
)
self.context = context
else:
context = self.context
if reset_context:
self.context.n_past = 0
self.context.n_predict = n_predict
self.context.top_k = top_k
self.context.top_p = top_p
self.context.min_p = min_p
self.context.temp = temp
self.context.n_batch = n_batch
self.context.repeat_penalty = repeat_penalty
self.context.repeat_last_n = repeat_last_n
self.context.context_erase = context_erase
@overload @overload
def generate_embeddings( def generate_embeddings(
self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool, self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
@ -486,20 +465,18 @@ class LLModel:
def prompt_model( def prompt_model(
self, self,
prompt: str, prompt : str,
prompt_template: str, callback : ResponseCallbackType,
callback: ResponseCallbackType, n_predict : int = 4096,
n_predict: int = 4096, top_k : int = 40,
top_k: int = 40, top_p : float = 0.9,
top_p: float = 0.9, min_p : float = 0.0,
min_p: float = 0.0, temp : float = 0.1,
temp: float = 0.1, n_batch : int = 8,
n_batch: int = 8, repeat_penalty : float = 1.2,
repeat_penalty: float = 1.2, repeat_last_n : int = 10,
repeat_last_n: int = 10, context_erase : float = 0.75,
context_erase: float = 0.75, reset_context : bool = False,
reset_context: bool = False,
special: bool = False,
): ):
""" """
Generate response from model from a prompt. Generate response from model from a prompt.
@ -522,34 +499,38 @@ class LLModel:
self.buffer.clear() self.buffer.clear()
self.buff_expecting_cont_bytes = 0 self.buff_expecting_cont_bytes = 0
self._set_context( context = LLModelPromptContext(
n_predict=n_predict, n_predict = n_predict,
top_k=top_k, top_k = top_k,
top_p=top_p, top_p = top_p,
min_p=min_p, min_p = min_p,
temp=temp, temp = temp,
n_batch=n_batch, n_batch = n_batch,
repeat_penalty=repeat_penalty, repeat_penalty = repeat_penalty,
repeat_last_n=repeat_last_n, repeat_last_n = repeat_last_n,
context_erase=context_erase, context_erase = context_erase,
reset_context=reset_context,
) )
llmodel.llmodel_prompt( error_msg: bytes | None = None
def error_callback(msg: bytes) -> None:
nonlocal error_msg
error_msg = msg
err = ctypes.c_char_p()
if not llmodel.llmodel_prompt(
self.model, self.model,
ctypes.c_char_p(prompt.encode()), ctypes.c_char_p(prompt.encode()),
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback), PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)), ResponseCallback(self._callback_decoder(callback)),
True, context,
self.context, ctypes.byref(err),
special, ):
ctypes.c_char_p(), s = err.value
) raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}")
def prompt_model_streaming( def prompt_model_streaming(
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any,
) -> Iterable[str]: ) -> Iterator[str]:
if self.model is None: if self.model is None:
self._raise_closed() self._raise_closed()
@ -568,15 +549,15 @@ class LLModel:
return _generator_callback return _generator_callback
def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs): def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, prompt_template, callback, **kwargs) self.prompt_model(prompt, callback, **kwargs)
output_queue.put(Sentinel.TERMINATING_SYMBOL) output_queue.put(Sentinel.TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator # Kick off llmodel_prompt in separate thread so we can return generator
# immediately # immediately
thread = threading.Thread( thread = threading.Thread(
target=run_llmodel_prompt, target=run_llmodel_prompt,
args=(prompt, prompt_template, _generator_callback_wrapper(callback)), args=(prompt, _generator_callback_wrapper(callback)),
kwargs=kwargs, kwargs=kwargs,
) )
thread.start() thread.start()
@ -631,5 +612,5 @@ class LLModel:
# Empty prompt callback # Empty prompt callback
@staticmethod @staticmethod
def _prompt_callback(token_id: int) -> bool: def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool:
return True return True
@ -4,37 +4,66 @@ Python only API for running all GPT4All models.
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import json
import os import os
import platform import platform
import re import re
import sys import sys
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime
from pathlib import Path from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload
import jinja2
import requests import requests
from jinja2.sandbox import ImmutableSandboxedEnvironment
from requests.exceptions import ChunkedEncodingError from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError from urllib3.exceptions import IncompleteRead, ProtocolError
from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult, from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
LLModel, ResponseCallbackType, empty_response_callback) LLModel, ResponseCallbackType, _operator_call, empty_response_callback)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias from typing_extensions import Self, TypeAlias
if sys.platform == 'darwin': if sys.platform == "darwin":
import fcntl import fcntl
# TODO: move to config # TODO: move to config
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all" DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n" ConfigType: TypeAlias = "dict[str, Any]"
ConfigType: TypeAlias = 'dict[str, Any]' # Environment setup adapted from HF transformers
MessageType: TypeAlias = 'dict[str, str]' @_operator_call
def _jinja_env() -> ImmutableSandboxedEnvironment:
def raise_exception(message: str) -> NoReturn:
raise jinja2.exceptions.TemplateError(message)
def tojson(obj: Any, indent: int | None = None) -> str:
return json.dumps(obj, ensure_ascii=False, indent=indent)
def strftime_now(fmt: str) -> str:
return datetime.now().strftime(fmt)
env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.filters["tojson" ] = tojson
env.globals["raise_exception"] = raise_exception
env.globals["strftime_now" ] = strftime_now
return env
class MessageType(TypedDict):
role: str
content: str
class ChatSession(NamedTuple):
template: jinja2.Template
history: list[MessageType]
class Embed4All: class Embed4All:
@ -54,7 +83,7 @@ class Embed4All:
kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor. kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor.
""" """
if model_name is None: if model_name is None:
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf' model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs) self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs)
def __enter__(self) -> Self: def __enter__(self) -> Self:
@ -145,18 +174,18 @@ class Embed4All:
dimensionality = -1 dimensionality = -1
else: else:
if dimensionality <= 0: if dimensionality <= 0:
raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}') raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}")
if dimensionality < self.MIN_DIMENSIONALITY: if dimensionality < self.MIN_DIMENSIONALITY:
warnings.warn( warnings.warn(
f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.' f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}."
' Performance may be degraded.' " Performance may be degraded."
) )
try: try:
do_mean = {"mean": True, "truncate": False}[long_text_mode] do_mean = {"mean": True, "truncate": False}[long_text_mode]
except KeyError: except KeyError:
raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}") raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb) result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
return result if return_dict else result['embeddings'] return result if return_dict else result["embeddings"]
class GPT4All: class GPT4All:
@ -204,8 +233,7 @@ class GPT4All:
""" """
self.model_type = model_type self.model_type = model_type
self._history: list[MessageType] | None = None self._chat_session: ChatSession | None = None
self._current_prompt_template: str = "{0}"
device_init = None device_init = None
if sys.platform == "darwin": if sys.platform == "darwin":
@ -264,7 +292,13 @@ class GPT4All:
@property @property
def current_chat_session(self) -> list[MessageType] | None: def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history) return None if self._chat_session is None else self._chat_session.history
@current_chat_session.setter
def current_chat_session(self, history: list[MessageType]) -> None:
if self._chat_session is None:
raise ValueError("current_chat_session may only be set when there is an active chat session")
self._chat_session.history[:] = history
@staticmethod @staticmethod
def list_models() -> list[ConfigType]: def list_models() -> list[ConfigType]:
@ -276,7 +310,7 @@ class GPT4All:
""" """
resp = requests.get("https://gpt4all.io/models/models3.json") resp = requests.get("https://gpt4all.io/models/models3.json")
if resp.status_code != 200: if resp.status_code != 200:
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}') raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}")
return resp.json() return resp.json()
@classmethod @classmethod
@ -306,15 +340,9 @@ class GPT4All:
# get the config for the model # get the config for the model
config: ConfigType = {} config: ConfigType = {}
if allow_download: if allow_download:
available_models = cls.list_models() models = cls.list_models()
if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None:
for m in available_models: config.update(model)
if model_filename == m["filename"]:
tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
# change to Python-style formatting
m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
config.update(m)
break
# Validate download directory # Validate download directory
if model_path is None: if model_path is None:
@ -378,13 +406,13 @@ class GPT4All:
headers = {} headers = {}
if offset: if offset:
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr) print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
headers['Range'] = f'bytes={offset}-' # resume incomplete response headers["Range"] = f"bytes={offset}-" # resume incomplete response
headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges
response = requests.get(url, stream=True, headers=headers) response = requests.get(url, stream=True, headers=headers)
if response.status_code not in (200, 206): if response.status_code not in (200, 206):
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}') raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')): if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")):
raise ValueError('Connection was interrupted and server does not support range requests') raise ValueError("Connection was interrupted and server does not support range requests")
if (enc := response.headers.get("Content-Encoding")) is not None: if (enc := response.headers.get("Content-Encoding")) is not None:
raise ValueError(f"Expected identity Content-Encoding, got {enc}") raise ValueError(f"Expected identity Content-Encoding, got {enc}")
return response return response
@ -483,19 +511,19 @@ class GPT4All:
def generate( def generate(
self, self,
prompt: str, prompt : str,
*, *,
max_tokens: int = 200, max_tokens : int = 200,
temp: float = 0.7, temp : float = 0.7,
top_k: int = 40, top_k : int = 40,
top_p: float = 0.4, top_p : float = 0.4,
min_p: float = 0.0, min_p : float = 0.0,
repeat_penalty: float = 1.18, repeat_penalty : float = 1.18,
repeat_last_n: int = 64, repeat_last_n : int = 64,
n_batch: int = 8, n_batch : int = 8,
n_predict: int | None = None, n_predict : int | None = None,
streaming: bool = False, streaming : bool = False,
callback: ResponseCallbackType = empty_response_callback, callback : ResponseCallbackType = empty_response_callback,
) -> Any: ) -> Any:
""" """
Generate outputs from any GPT4All model. Generate outputs from any GPT4All model.
@ -520,122 +548,94 @@ class GPT4All:
# Preparing the model request # Preparing the model request
generate_kwargs: dict[str, Any] = dict( generate_kwargs: dict[str, Any] = dict(
temp=temp, temp = temp,
top_k=top_k, top_k = top_k,
top_p=top_p, top_p = top_p,
min_p=min_p, min_p = min_p,
repeat_penalty=repeat_penalty, repeat_penalty = repeat_penalty,
repeat_last_n=repeat_last_n, repeat_last_n = repeat_last_n,
n_batch=n_batch, n_batch = n_batch,
n_predict=n_predict if n_predict is not None else max_tokens, n_predict = n_predict if n_predict is not None else max_tokens,
) )
if self._history is not None:
# check if there is only one message, i.e. system prompt:
reset = len(self._history) == 1
self._history.append({"role": "user", "content": prompt})
fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined]
if fct_func is GPT4All._format_chat_prompt_template:
if reset:
# ingest system prompt
# use "%1%2" and not "%1" to avoid implicit whitespace
self.model.prompt_model(self._history[0]["content"], "%1%2",
empty_response_callback,
n_batch=n_batch, n_predict=0, reset_context=True, special=True)
prompt_template = self._current_prompt_template.format("%1", "%2")
else:
warnings.warn(
"_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
DeprecationWarning,
)
# special tokens won't be processed
prompt = self._format_chat_prompt_template(
self._history[-1:],
self._history[0]["content"] if reset else "",
)
prompt_template = "%1"
generate_kwargs["reset_context"] = reset
else:
prompt_template = "%1"
generate_kwargs["reset_context"] = True
# Prepare the callback, process the model response # Prepare the callback, process the model response
output_collector: list[MessageType] full_response = ""
output_collector = [
{"content": ""}
] # placeholder for the self._history if chat session is not activated
if self._history is not None: def _callback_wrapper(token_id: int, response: str) -> bool:
self._history.append({"role": "assistant", "content": ""}) nonlocal full_response
output_collector = self._history full_response += response
return callback(token_id, response)
def _callback_wrapper( last_msg_rendered = prompt
callback: ResponseCallbackType, if self._chat_session is not None:
output_collector: list[MessageType], session = self._chat_session
) -> ResponseCallbackType: def render(messages: list[MessageType]) -> str:
def _callback(token_id: int, response: str) -> bool: return session.template.render(
nonlocal callback, output_collector messages=messages,
add_generation_prompt=True,
**self.model.special_tokens_map,
)
session.history.append(MessageType(role="user", content=prompt))
prompt = render(session.history)
if len(session.history) > 1:
last_msg_rendered = render(session.history[-1:])
output_collector[-1]["content"] += response # Check request length
last_msg_len = self.model.count_prompt_tokens(last_msg_rendered)
return callback(token_id, response) if last_msg_len > (limit := self.model.n_ctx - 4):
raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).")
return _callback
# Send the request to the model # Send the request to the model
if streaming: if streaming:
return self.model.prompt_model_streaming( def stream() -> Iterator[str]:
prompt, yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
prompt_template, if self._chat_session is not None:
_callback_wrapper(callback, output_collector), self._chat_session.history.append(MessageType(role="assistant", content=full_response))
**generate_kwargs, return stream()
)
self.model.prompt_model( self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
prompt, if self._chat_session is not None:
prompt_template, self._chat_session.history.append(MessageType(role="assistant", content=full_response))
_callback_wrapper(callback, output_collector), return full_response
**generate_kwargs,
)
return output_collector[-1]["content"]
@contextmanager @contextmanager
def chat_session( def chat_session(
self, self,
system_prompt: str | None = None, system_message: str | Literal[False] | None = None,
prompt_template: str | None = None, chat_template: str | None = None,
): ):
""" """
Context manager to hold an inference optimized chat session with a GPT4All model. Context manager to hold an inference optimized chat session with a GPT4All model.
Args: Args:
system_prompt: An initial instruction for the model. system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
prompt_template: Template for the prompts with {0} being replaced by the user message. chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
""" """
if system_prompt is None: if system_message is None:
system_prompt = self.config.get("systemPrompt", "") system_message = self.config.get("systemMessage", False)
if prompt_template is None: if chat_template is None:
if (tmpl := self.config.get("promptTemplate")) is None: if "name" not in self.config:
warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template " raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
"is deprecated. Defaulting to Alpaca.", DeprecationWarning) if "chatTemplate" not in self.config:
tmpl = DEFAULT_PROMPT_TEMPLATE raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
prompt_template = tmpl "currently implemented. Please pass a template to chat_session() directly.")
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
if re.search(r"%1(?![0-9])", prompt_template): history = []
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt " if system_message is not False:
"placeholder, please use '{0}' instead.") history.append(MessageType(role="system", content=system_message))
self._chat_session = ChatSession(
self._history = [{"role": "system", "content": system_prompt}] template=_jinja_env.from_string(chat_template),
self._current_prompt_template = prompt_template history=history,
)
try: try:
yield self yield self
finally: finally:
self._history = None self._chat_session = None
self._current_prompt_template = "{0}"
@staticmethod @staticmethod
def list_gpus() -> list[str]: def list_gpus() -> list[str]:
@ -647,43 +647,6 @@ class GPT4All:
""" """
return LLModel.list_gpus() return LLModel.list_gpus()
def _format_chat_prompt_template(
self,
messages: list[MessageType],
default_prompt_header: str = "",
default_prompt_footer: str = "",
) -> str:
"""
Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
Warning:
This function was deprecated in version 2.3.0, and will be removed in a future release.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
Returns:
Formatted prompt.
"""
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
for message in messages:
if message["role"] == "user":
user_message = self._current_prompt_template.format(message["content"])
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = message["content"] + "\n"
full_prompt += assistant_message
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
return full_prompt
def append_extension_if_missing(model_name): def append_extension_if_missing(model_name):
if not model_name.endswith((".bin", ".gguf")): if not model_name.endswith((".bin", ".gguf")):
@ -696,7 +659,7 @@ class _HasFileno(Protocol):
def _fsync(fd: int | _HasFileno) -> None: def _fsync(fd: int | _HasFileno) -> None:
if sys.platform == 'darwin': if sys.platform == "darwin":
# Apple's fsync does not flush the drive write cache # Apple's fsync does not flush the drive write cache
try: try:
fcntl.fcntl(fd, fcntl.F_FULLFSYNC) fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
@ -14,6 +14,7 @@ nav:
- 'Models' : 'gpt4all_desktop/models.md' - 'Models' : 'gpt4all_desktop/models.md'
- 'LocalDocs' : 'gpt4all_desktop/localdocs.md' - 'LocalDocs' : 'gpt4all_desktop/localdocs.md'
- 'Settings' : 'gpt4all_desktop/settings.md' - 'Settings' : 'gpt4all_desktop/settings.md'
- 'Chat Templates' : 'gpt4all_desktop/chat_templates.md'
- 'Cookbook': - 'Cookbook':
- 'Local AI Chat with Microsoft Excel': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-microsoft-excel.md' - 'Local AI Chat with Microsoft Excel': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-microsoft-excel.md'
- 'Local AI Chat with your Google Drive': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-google-drive.md' - 'Local AI Chat with your Google Drive': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-google-drive.md'
@ -88,9 +88,10 @@ setup(
python_requires='>=3.8', python_requires='>=3.8',
packages=find_packages(), packages=find_packages(),
install_requires=[ install_requires=[
'importlib_resources; python_version < "3.9"',
'jinja2~=3.1',
'requests', 'requests',
'tqdm', 'tqdm',
'importlib_resources; python_version < "3.9"',
'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"', 'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
], ],
extras_require={ extras_require={
@ -190,6 +190,7 @@ qt_add_executable(chat
src/database.cpp src/database.h src/database.cpp src/database.h
src/download.cpp src/download.h src/download.cpp src/download.h
src/embllm.cpp src/embllm.h src/embllm.cpp src/embllm.h
src/jinja_helpers.cpp src/jinja_helpers.h
src/llm.cpp src/llm.h src/llm.cpp src/llm.h
src/localdocs.cpp src/localdocs.h src/localdocs.cpp src/localdocs.h
src/localdocsmodel.cpp src/localdocsmodel.h src/localdocsmodel.cpp src/localdocsmodel.h
@ -215,6 +216,7 @@ qt_add_qml_module(chat
qml/ApplicationSettings.qml qml/ApplicationSettings.qml
qml/ChatDrawer.qml qml/ChatDrawer.qml
qml/ChatItemView.qml qml/ChatItemView.qml
qml/ChatMessageButton.qml
qml/ChatView.qml qml/ChatView.qml
qml/CollectionsDrawer.qml qml/CollectionsDrawer.qml
qml/HomeView.qml qml/HomeView.qml
@ -227,7 +229,7 @@ qt_add_qml_module(chat
qml/PopupDialog.qml qml/PopupDialog.qml
qml/SettingsView.qml qml/SettingsView.qml
qml/StartupDialog.qml qml/StartupDialog.qml
qml/SwitchModelDialog.qml qml/ConfirmationDialog.qml
qml/Theme.qml qml/Theme.qml
qml/ThumbsDownDialog.qml qml/ThumbsDownDialog.qml
qml/Toast.qml qml/Toast.qml
@ -386,7 +388,7 @@ target_include_directories(chat PRIVATE deps/usearch/include
target_link_libraries(chat target_link_libraries(chat
PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg) PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg)
target_link_libraries(chat target_link_libraries(chat
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx) PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx jinja2cpp)
if (APPLE) if (APPLE)
target_link_libraries(chat PRIVATE ${COCOA_LIBRARY}) target_link_libraries(chat PRIVATE ${COCOA_LIBRARY})
@ -11,3 +11,5 @@ add_subdirectory(DuckX)
set(QT_VERSION_MAJOR 6) set(QT_VERSION_MAJOR 6)
add_subdirectory(QXlsx/QXlsx) add_subdirectory(QXlsx/QXlsx)
add_subdirectory(Jinja2Cpp)
@ -0,0 +1 @@
Subproject commit b2a716798bfa63c7dae303fc1e272964c4e1f9ee
@ -1,3 +1 @@
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="#000000" viewBox="0 0 256 256"><path d="M227.31,73.37,182.63,28.68a16,16,0,0,0-22.63,0L36.69,152A15.86,15.86,0,0,0,32,163.31V208a16,16,0,0,0,16,16H92.69A15.86,15.86,0,0,0,104,219.31L227.31,96a16,16,0,0,0,0-22.63ZM92.69,208H48V163.31l88-88L180.69,120ZM192,108.68,147.31,64l24-24L216,84.68Z"></path></svg>
<path d="M28.4138 9.17125L22.8288 3.585C22.643 3.39924 22.4225 3.25188 22.1799 3.15134C21.9372 3.0508 21.6771 2.99905 21.4144 2.99905C21.1517 2.99905 20.8916 3.0508 20.6489 3.15134C20.4062 3.25188 20.1857 3.39924 20 3.585L4.58626 19C4.39973 19.185 4.25185 19.4053 4.15121 19.648C4.05057 19.8907 3.99917 20.151 4.00001 20.4138V26C4.00001 26.5304 4.21072 27.0391 4.5858 27.4142C4.96087 27.7893 5.46958 28 6.00001 28H11.5863C11.849 28.0008 12.1093 27.9494 12.352 27.8488C12.5947 27.7482 12.815 27.6003 13 27.4138L28.4138 12C28.5995 11.8143 28.7469 11.5938 28.8474 11.3511C28.948 11.1084 28.9997 10.8483 28.9997 10.5856C28.9997 10.3229 28.948 10.0628 28.8474 9.82015C28.7469 9.57747 28.5995 9.35698 28.4138 9.17125ZM6.41376 20L17 9.41375L19.0863 11.5L8.50001 22.085L6.41376 20ZM6.00001 22.4138L9.58626 26H6.00001V22.4138ZM12 25.5863L9.91376 23.5L20.5 12.9138L22.5863 15L12 25.5863ZM24 13.5863L18.4138 8L21.4138 5L27 10.585L24 13.5863Z" fill="black"/>
</svg>
@ -29,7 +29,8 @@
"description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>", "description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
"url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf", "url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf",
"promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2", "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>" "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
}, },
{ {
"order": "c", "order": "c",
@ -45,7 +46,8 @@
"description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>", "description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
"url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf", "url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf",
"promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2", "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>" "systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
}, },
{ {
"order": "d", "order": "d",
@ -77,7 +79,8 @@
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>Strong overall fast instruction following model</strong><br><ul><li>Fast responses</li><li>Trained by Mistral AI<li>Uncensored</li><li>Licensed for commercial use</li></ul>", "description": "<strong>Strong overall fast instruction following model</strong><br><ul><li>Fast responses</li><li>Trained by Mistral AI<li>Uncensored</li><li>Licensed for commercial use</li></ul>",
"url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf", "url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf",
"promptTemplate": "[INST] %1 [/INST]" "promptTemplate": "[INST] %1 [/INST]",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_start = 1 %}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (message['role'] == 'user') != ((loop.index0 - loop_start) % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}"
}, },
{ {
"order": "f", "order": "f",
@ -125,7 +128,8 @@
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>Very fast model with good quality</strong><br><ul><li>Fastest responses</li><li>Instruction based</li><li>Trained by TII<li>Finetuned by Nomic AI<li>Licensed for commercial use</ul>", "description": "<strong>Very fast model with good quality</strong><br><ul><li>Fastest responses</li><li>Instruction based</li><li>Trained by TII<li>Finetuned by Nomic AI<li>Licensed for commercial use</ul>",
"url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf", "url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf",
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n" "promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User: ' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Assistant: ' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Assistant:' }}\n{%- endif %}"
}, },
{ {
"order": "i", "order": "i",
@ -140,7 +144,8 @@
"type": "LLaMA2", "type": "LLaMA2",
"systemPrompt": "", "systemPrompt": "",
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>", "description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf" "url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
}, },
{ {
"order": "j", "order": "j",
@ -155,7 +160,8 @@
"type": "LLaMA2", "type": "LLaMA2",
"systemPrompt": "", "systemPrompt": "",
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>", "description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf" "url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
}, },
{ {
"order": "k", "order": "k",
@ -170,7 +176,9 @@
"type": "LLaMA2", "type": "LLaMA2",
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>Strong overall larger model</strong><br><ul><li>Instruction based<li>Gives very long responses<li>Finetuned with only 1k of high-quality data<li>Trained by Microsoft and Peking University<li>Cannot be used commercially</ul>", "description": "<strong>Strong overall larger model</strong><br><ul><li>Instruction based<li>Gives very long responses<li>Finetuned with only 1k of high-quality data<li>Trained by Microsoft and Peking University<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf" "url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + ' ' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in loop_messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- if (loop.index0 - loop_start) % 2 == 0 %}\n {{- ' ' }}\n {%- else %}\n {{- eos_token }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- 'ASSISTANT:' }}\n{%- endif %}",
"systemMessage": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
}, },
{ {
"order": "l", "order": "l",
@ -186,7 +194,8 @@
"description": "<strong>Ghost 7B v0.9.1</strong> fast, powerful and smooth for Vietnamese and English languages.", "description": "<strong>Ghost 7B v0.9.1</strong> fast, powerful and smooth for Vietnamese and English languages.",
"url": "https://huggingface.co/lamhieu/ghost-7b-v0.9.1-gguf/resolve/main/ghost-7b-v0.9.1-Q4_0.gguf", "url": "https://huggingface.co/lamhieu/ghost-7b-v0.9.1-gguf/resolve/main/ghost-7b-v0.9.1-Q4_0.gguf",
"promptTemplate": "<|user|>\n%1</s>\n<|assistant|>\n%2</s>\n", "promptTemplate": "<|user|>\n%1</s>\n<|assistant|>\n%2</s>\n",
"systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>" "systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>",
"systemMessage": "You are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior."
}, },
{ {
"order": "m", "order": "m",
@ -202,7 +211,8 @@
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>Extremely good model</strong><br><ul><li>Instruction based<li>Gives long responses<li>Curated with 300,000 uncensored instructions<li>Trained by Nous Research<li>Cannot be used commercially</ul>", "description": "<strong>Extremely good model</strong><br><ul><li>Instruction based<li>Gives long responses<li>Curated with 300,000 uncensored instructions<li>Trained by Nous Research<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf", "url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf",
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n" "promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Instruction:\\n' }}\n{%- endif %}"
}, },
{ {
"order": "n", "order": "n",
@ -217,7 +227,9 @@
"type": "LLaMA", "type": "LLaMA",
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>Very good overall model</strong><br><ul><li>Instruction based<li>Based on the same dataset as Groovy<li>Slower than Groovy, with higher quality responses<li>Trained by Nomic AI<li>Cannot be used commercially</ul>", "description": "<strong>Very good overall model</strong><br><ul><li>Instruction based<li>Based on the same dataset as Groovy<li>Slower than Groovy, with higher quality responses<li>Trained by Nomic AI<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}",
"systemMessage": "Below is an instruction that describes a task. Write a response that appropriately completes the request."
}, },
{ {
"order": "o", "order": "o",
@ -234,7 +246,8 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>", "description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf", "url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n", "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n" "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
}, },
{ {
"order": "p", "order": "p",
@ -250,7 +263,8 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>", "description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf", "url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n", "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n" "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
"chatTemplate": "{%- for message in messages %}\n {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
}, },
{ {
"order": "q", "order": "q",
@ -282,7 +296,8 @@
"description": "<strong>Small version of new model with novel dataset</strong><br><ul><li>Very fast responses</li><li>Instruction based</li><li>Explain tuned datasets</li><li>Orca Research Paper dataset construction approaches</li><li>Cannot be used commercially</li></ul>", "description": "<strong>Small version of new model with novel dataset</strong><br><ul><li>Very fast responses</li><li>Instruction based</li><li>Explain tuned datasets</li><li>Orca Research Paper dataset construction approaches</li><li>Cannot be used commercially</li></ul>",
"url": "https://gpt4all.io/models/gguf/orca-mini-3b-gguf2-q4_0.gguf", "url": "https://gpt4all.io/models/gguf/orca-mini-3b-gguf2-q4_0.gguf",
"promptTemplate": "### User:\n%1\n\n### Response:\n", "promptTemplate": "### User:\n%1\n\n### Response:\n",
"systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n" "systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n",
"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {{- '### System:\\n' + messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if message['role'] == 'user' %}\n {{- '### User:\\n' + message['content'] + '\\n\\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '### Response:\\n' }}\n{%- endif %}"
}, },
{ {
"order": "s", "order": "s",
@ -299,7 +314,8 @@
"systemPrompt": "", "systemPrompt": "",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>", "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf",
"chatTemplate": null
}, },
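The "chatTemplate" values above are plain Jinja templates: they receive the conversation as a "messages" list of role/content pairs plus an "add_generation_prompt" flag, replacing the old "%1"/"%2" placeholders of "promptTemplate", while embedding and code-completion entries opt out with "chatTemplate": null. As an illustrative sketch only (not code from this change), a ChatML-style template like the one above can be exercised from C++ with a Jinja engine such as Jinja2Cpp; the standalone program and its sample messages below are assumptions for the example.

// Illustrative only: render a ChatML-style chat template like the ones above
// using the Jinja2Cpp library. "messages" and "add_generation_prompt" are the
// variables the templates in models3.json expect; everything else here is an
// assumption made up for this sketch.
#include <jinja2cpp/template.h>

#include <iostream>

int main()
{
    jinja2::Template tpl;
    tpl.Load(
        "{%- for message in messages %}\n"
        "    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n"
        "{%- endfor %}\n"
        "{%- if add_generation_prompt %}\n"
        "    {{- '<|im_start|>assistant\\n' }}\n"
        "{%- endif %}");

    jinja2::ValuesList messages{
        jinja2::ValuesMap{{"role", "system"}, {"content", "You are a helpful assistant."}},
        jinja2::ValuesMap{{"role", "user"},   {"content", "Hello!"}},
    };

    auto prompt = tpl.RenderAsString({
        {"messages", messages},
        {"add_generation_prompt", true},
    });
    if (prompt)
        std::cout << *prompt << '\n'; // <|im_start|>-framed prompt text for the model
    return 0;
}

With the two sample messages, the rendered string is the <|im_start|>/<|im_end|>-framed conversation ending in an open "<|im_start|>assistant" turn, which is the kind of prompt string the chat template is meant to produce.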
{ {
"order": "t", "order": "t",
@ -316,7 +332,8 @@
"systemPrompt": "", "systemPrompt": "",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>", "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf",
"chatTemplate": null
}, },
{ {
"order": "u", "order": "u",
@ -333,7 +350,8 @@
"systemPrompt": "", "systemPrompt": "",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>", "description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf",
"chatTemplate": null
}, },
{ {
"order": "v", "order": "v",
@ -351,7 +369,8 @@
"embeddingModel": true, "embeddingModel": true,
"systemPrompt": "", "systemPrompt": "",
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)", "description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf" "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf",
"chatTemplate": null
}, },
{ {
"order": "w", "order": "w",
@ -367,7 +386,8 @@
"type": "Bert", "type": "Bert",
"embeddingModel": true, "embeddingModel": true,
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)", "description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf" "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf",
"chatTemplate": null
}, },
{ {
"order": "x", "order": "x",
@ -383,7 +403,9 @@
"description": "<strong>Mistral-based model for German-language applications</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by ellamind<li>Finetuned on German instruction and chat data</a><li>Licensed for commercial use</ul>", "description": "<strong>Mistral-based model for German-language applications</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by ellamind<li>Finetuned on German instruction and chat data</a><li>Licensed for commercial use</ul>",
"url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf", "url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf",
"promptTemplate": "USER: %1 ASSISTANT: ", "promptTemplate": "USER: %1 ASSISTANT: ",
"systemPrompt": "Du bist ein hilfreicher Assistent. " "systemPrompt": "Du bist ein hilfreicher Assistent. ",
"chatTemplate": "{%- set system_message = false %}\n{%- if messages[0]['role'] == 'system' %}\n {%- set loop_start = 1 %}\n {%- set system_message = true %}\n {{- messages[0]['content'] }}\n{%- else %}\n {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n {%- if loop.index0 >= loop_start %}\n {%- if (not loop.first) or (system_message is not none) %}\n {{- ' ' }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {{- 'USER: ' + message['content'] }}\n {%- elif message['role'] == 'assistant' %}\n {{- 'ASSISTANT: ' + message['content'] }}\n {%- else %}\n {{- raise_exception('After the optional system message, conversation roles must be either user or assistant.') }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {%- if messages %}\n {{- ' ' }}\n {%- endif %}\n {{- 'ASSISTANT:' }}\n{%- endif %}",
"systemMessage": "Du bist ein hilfreicher Assistent."
}, },
{ {
"order": "y", "order": "y",
@ -400,7 +422,8 @@
"embeddingModel": true, "embeddingModel": true,
"systemPrompt": "", "systemPrompt": "",
"description": "nomic-embed-text-v1", "description": "nomic-embed-text-v1",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf" "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf",
"chatTemplate": null
}, },
{ {
"order": "z", "order": "z",
@ -417,7 +440,8 @@
"embeddingModel": true, "embeddingModel": true,
"systemPrompt": "", "systemPrompt": "",
"description": "nomic-embed-text-v1.5", "description": "nomic-embed-text-v1.5",
"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf" "url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf",
"chatTemplate": null
}, },
{ {
"order": "zzz", "order": "zzz",
View File
@ -10,7 +10,7 @@ import network
import llm import llm
MySettingsTab { MySettingsTab {
onRestoreDefaultsClicked: { onRestoreDefaults: {
MySettings.restoreApplicationDefaults(); MySettings.restoreApplicationDefaults();
} }
title: qsTr("Application") title: qsTr("Application")
@ -486,23 +486,6 @@ MySettingsTab {
Accessible.name: nThreadsLabel.text Accessible.name: nThreadsLabel.text
Accessible.description: ToolTip.text Accessible.description: ToolTip.text
} }
MySettingsLabel {
id: saveChatsContextLabel
text: qsTr("Save Chat Context")
helpText: qsTr("Save the chat model's state to disk for faster loading. WARNING: Uses ~2GB per chat.")
Layout.row: 12
Layout.column: 0
}
MyCheckBox {
id: saveChatsContextBox
Layout.row: 12
Layout.column: 2
Layout.alignment: Qt.AlignRight
checked: MySettings.saveChatsContext
onClicked: {
MySettings.saveChatsContext = !MySettings.saveChatsContext
}
}
MySettingsLabel { MySettingsLabel {
id: trayLabel id: trayLabel
text: qsTr("Enable System Tray") text: qsTr("Enable System Tray")
View File
@ -8,8 +8,23 @@ import QtQuick.Layouts
import gpt4all import gpt4all
import mysettings import mysettings
ColumnLayout {
property var inputBoxText: null
signal setInputBoxText(text: string)
Item {
Layout.fillWidth: true
Layout.maximumWidth: parent.width
Layout.preferredHeight: gridLayout.height
HoverHandler { id: hoverArea }
GridLayout { GridLayout {
rows: 5 id: gridLayout
anchors.left: parent.left
anchors.right: parent.right
columns: 2 columns: 2
Item { Item {
@ -40,7 +55,7 @@ GridLayout {
to: 360 to: 360
duration: 1000 duration: 1000
loops: Animation.Infinite loops: Animation.Infinite
running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText) running: isCurrentResponse && currentChat.responseInProgress
} }
} }
} }
@ -73,13 +88,11 @@ GridLayout {
color: theme.mutedTextColor color: theme.mutedTextColor
} }
RowLayout { RowLayout {
visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText) visible: isCurrentResponse && (value === "" && currentChat.responseInProgress)
Text { Text {
color: theme.mutedTextColor color: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarger font.pixelSize: theme.fontSizeLarger
text: { text: {
if (currentChat.restoringFromText)
return qsTr("restoring from text ...");
switch (currentChat.responseState) { switch (currentChat.responseState) {
case Chat.ResponseStopped: return qsTr("response stopped ..."); case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", ")); case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@ -99,10 +112,11 @@ GridLayout {
Layout.row: 1 Layout.row: 1
Layout.column: 1 Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
spacing: 20 spacing: 10
Flow { Flow {
id: attachedUrlsFlow id: attachedUrlsFlow
Layout.fillWidth: true Layout.fillWidth: true
Layout.bottomMargin: 10
spacing: 10 spacing: 10
visible: promptAttachments.length !== 0 visible: promptAttachments.length !== 0
Repeater { Repeater {
@ -156,7 +170,7 @@ GridLayout {
focus: false focus: false
readOnly: true readOnly: true
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
cursorVisible: currentResponse ? currentChat.responseInProgress : false cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false
cursorPosition: text.length cursorPosition: text.length
TapHandler { TapHandler {
id: tapHandler id: tapHandler
@ -183,12 +197,12 @@ GridLayout {
} }
onLinkActivated: function(link) { onLinkActivated: function(link) {
if (!currentResponse || !currentChat.responseInProgress) if (!isCurrentResponse || !currentChat.responseInProgress)
Qt.openUrlExternally(link) Qt.openUrlExternally(link)
} }
onLinkHovered: function (link) { onLinkHovered: function (link) {
if (!currentResponse || !currentChat.responseInProgress) if (!isCurrentResponse || !currentChat.responseInProgress)
statusBar.externalHoveredLink = link statusBar.externalHoveredLink = link
} }
@ -239,13 +253,19 @@ GridLayout {
textProcessor.setValue(value); textProcessor.setValue(value);
} }
property bool textProcessorReady: false
Component.onCompleted: { Component.onCompleted: {
resetChatViewTextProcessor(); resetChatViewTextProcessor();
chatModel.valueChanged.connect(function(i, value) { textProcessorReady = true;
if (index === i) }
Connections {
target: chatModel
function onValueChanged(i, value) {
if (myTextArea.textProcessorReady && index === i)
textProcessor.setValue(value); textProcessor.setValue(value);
} }
);
} }
Connections { Connections {
@ -282,67 +302,6 @@ GridLayout {
Network.sendConversation(currentChat.id, getConversationJson()); Network.sendConversation(currentChat.id, getConversationJson());
} }
} }
Column {
Layout.alignment: Qt.AlignRight
Layout.rightMargin: 15
visible: name === "Response: " &&
(!currentResponse || !currentChat.responseInProgress) && MySettings.networkIsActive
spacing: 10
Item {
width: childrenRect.width
height: childrenRect.height
MyToolButton {
id: thumbsUp
width: 24
height: 24
imageWidth: width
imageHeight: height
opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
source: "qrc:/gpt4all/icons/thumbs_up.svg"
Accessible.name: qsTr("Thumbs up")
Accessible.description: qsTr("Gives a thumbs up to the response")
onClicked: {
if (thumbsUpState && !thumbsDownState)
return
chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
MyToolButton {
id: thumbsDown
anchors.top: thumbsUp.top
anchors.topMargin: 3
anchors.left: thumbsUp.right
anchors.leftMargin: 3
width: 24
height: 24
imageWidth: width
imageHeight: height
checked: thumbsDownState
opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
transform: [
Matrix4x4 {
matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
},
Translate {
x: thumbsDown.width
}
]
source: "qrc:/gpt4all/icons/thumbs_down.svg"
Accessible.name: qsTr("Thumbs down")
Accessible.description: qsTr("Opens thumbs down dialog")
onClicked: {
thumbsDownDialog.open()
}
}
}
}
} }
Item { Item {
@ -353,11 +312,13 @@ GridLayout {
Layout.preferredWidth: childrenRect.width Layout.preferredWidth: childrenRect.width
Layout.preferredHeight: childrenRect.height Layout.preferredHeight: childrenRect.height
visible: { visible: {
if (name !== "Response: ")
return false
if (consolidatedSources.length === 0) if (consolidatedSources.length === 0)
return false return false
if (!MySettings.localDocsShowReferences) if (!MySettings.localDocsShowReferences)
return false return false
if (currentResponse && currentChat.responseInProgress if (isCurrentResponse && currentChat.responseInProgress
&& currentChat.responseState !== Chat.GeneratingQuestions ) && currentChat.responseState !== Chat.GeneratingQuestions )
return false return false
return true return true
@ -443,7 +404,7 @@ GridLayout {
return false return false
if (!MySettings.localDocsShowReferences) if (!MySettings.localDocsShowReferences)
return false return false
if (currentResponse && currentChat.responseInProgress if (isCurrentResponse && currentChat.responseInProgress
&& currentChat.responseState !== Chat.GeneratingQuestions ) && currentChat.responseState !== Chat.GeneratingQuestions )
return false return false
return true return true
@ -566,8 +527,139 @@ GridLayout {
} }
} }
ConfirmationDialog {
id: editPromptDialog
dialogTitle: qsTr("Edit this prompt?")
description: qsTr("The existing response and all later messages will be permanently erased.")
onAccepted: {
const msg = currentChat.popPrompt(index);
if (msg !== null)
setInputBoxText(msg);
}
}
ConfirmationDialog {
id: redoResponseDialog
dialogTitle: qsTr("Redo this response?")
description: qsTr("The existing response and all later messages will be permanently erased.")
onAccepted: currentChat.regenerateResponse(index)
}
RowLayout {
id: buttonRow
Layout.row: 4
Layout.column: 1
Layout.maximumWidth: parent.width
Layout.fillWidth: false
Layout.alignment: Qt.AlignLeft | Qt.AlignTop
spacing: 3
visible: !isCurrentResponse || !currentChat.responseInProgress
enabled: opacity > 0
opacity: hoverArea.hovered
readonly property var canModify: !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
Behavior on opacity {
OpacityAnimator { duration: 30 }
}
ChatMessageButton {
visible: parent.canModify && model.name === "Prompt: "
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
source: "qrc:/gpt4all/icons/edit.svg"
onClicked: {
if (inputBoxText === "")
editPromptDialog.open();
}
name: qsTr("Edit")
}
ChatMessageButton {
visible: parent.canModify && model.name === "Response: "
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
name: qsTr("Redo")
source: "qrc:/gpt4all/icons/regenerate.svg"
onClicked: redoResponseDialog.open()
}
ChatMessageButton {
Layout.maximumWidth: 24
Layout.maximumHeight: 24
Layout.alignment: Qt.AlignVCenter
Layout.fillWidth: false
name: qsTr("Copy")
source: "qrc:/gpt4all/icons/copy.svg"
onClicked: {
myTextArea.selectAll();
myTextArea.copy();
myTextArea.deselect();
}
}
Item {
visible: name === "Response: " && MySettings.networkIsActive
Layout.alignment: Qt.AlignVCenter
Layout.preferredWidth: childrenRect.width
Layout.preferredHeight: childrenRect.height
Layout.fillWidth: false
ChatMessageButton {
id: thumbsUp
anchors.left: parent.left
anchors.verticalCenter: parent.verticalCenter
opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
source: "qrc:/gpt4all/icons/thumbs_up.svg"
name: qsTr("Like response")
onClicked: {
if (thumbsUpState && !thumbsDownState)
return
chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
ChatMessageButton {
id: thumbsDown
anchors.top: thumbsUp.top
anchors.topMargin: buttonRow.spacing
anchors.left: thumbsUp.right
anchors.leftMargin: buttonRow.spacing
checked: thumbsDownState
opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
bgTransform: [
Matrix4x4 {
matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
},
Translate {
x: thumbsDown.width
}
]
source: "qrc:/gpt4all/icons/thumbs_down.svg"
name: qsTr("Dislike response")
onClicked: {
thumbsDownDialog.open()
}
}
}
}
} // GridLayout
} // Item
GridLayout {
Layout.fillWidth: true
Layout.maximumWidth: parent.width
function shouldShowSuggestions() { function shouldShowSuggestions() {
if (!currentResponse) if (!isCurrentResponse)
return false; return false;
if (MySettings.suggestionMode === 2) // Off if (MySettings.suggestionMode === 2) // Off
return false; return false;
@ -577,8 +669,8 @@ GridLayout {
} }
Item { Item {
visible: shouldShowSuggestions() visible: parent.shouldShowSuggestions()
Layout.row: 4 Layout.row: 5
Layout.column: 0 Layout.column: 0
Layout.topMargin: 20 Layout.topMargin: 20
Layout.alignment: Qt.AlignVCenter | Qt.AlignRight Layout.alignment: Qt.AlignVCenter | Qt.AlignRight
@ -601,8 +693,8 @@ GridLayout {
} }
Item { Item {
visible: shouldShowSuggestions() visible: parent.shouldShowSuggestions()
Layout.row: 4 Layout.row: 5
Layout.column: 1 Layout.column: 1
Layout.topMargin: 20 Layout.topMargin: 20
Layout.fillWidth: true Layout.fillWidth: true
@ -627,8 +719,8 @@ GridLayout {
} }
ColumnLayout { ColumnLayout {
visible: shouldShowSuggestions() visible: parent.shouldShowSuggestions()
Layout.row: 5 Layout.row: 6
Layout.column: 1 Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
Layout.minimumHeight: 1 Layout.minimumHeight: 1
@ -786,4 +878,7 @@ GridLayout {
} }
} }
} }
}
} // GridLayout
} // ColumnLayout
View File
@ -0,0 +1,20 @@
import QtQuick
import QtQuick.Controls
import gpt4all
MyToolButton {
property string name
width: 24
height: 24
imageWidth: width
imageHeight: height
ToolTip {
visible: parent.hovered
y: parent.height * 1.5
text: name
delay: Qt.styleHints.mousePressAndHoldInterval
}
Accessible.name: name
}
View File
@ -24,6 +24,12 @@ Rectangle {
property var currentChat: ChatListModel.currentChat property var currentChat: ChatListModel.currentChat
property var chatModel: currentChat.chatModel property var chatModel: currentChat.chatModel
property var currentModelInfo: currentChat && currentChat.modelInfo
property var currentModelId: null
onCurrentModelInfoChanged: {
const newId = currentModelInfo && currentModelInfo.id;
if (currentModelId !== newId) { currentModelId = newId; }
}
signal addCollectionViewRequested() signal addCollectionViewRequested()
signal addModelViewRequested() signal addModelViewRequested()
@ -79,14 +85,11 @@ Rectangle {
function open_(msg) { message = msg; open(); } function open_(msg) { message = msg; open(); }
} }
SwitchModelDialog { ConfirmationDialog {
id: switchModelDialog id: switchModelDialog
anchors.centerIn: parent property int index: -1
Item { dialogTitle: qsTr("Erase conversation?")
Accessible.role: Accessible.Dialog description: qsTr("Changing the model will erase the current conversation.")
Accessible.name: qsTr("Switch model dialog")
Accessible.description: qsTr("Warn the user if they switch models, then context will be erased")
}
} }
PopupDialog { PopupDialog {
@ -103,6 +106,16 @@ Rectangle {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
} }
ConfirmationDialog {
id: resetContextDialog
dialogTitle: qsTr("Erase conversation?")
description: qsTr("The entire chat will be erased.")
onAccepted: {
Network.trackChatEvent("reset_context", { "length": chatModel.count });
currentChat.reset();
}
}
function getConversation() { function getConversation() {
var conversation = ""; var conversation = "";
for (var i = 0; i < chatModel.count; i++) { for (var i = 0; i < chatModel.count; i++) {
@ -703,7 +716,7 @@ Rectangle {
if (i !== -1) { if (i !== -1) {
defaultModel = comboBox.valueAt(i); defaultModel = comboBox.valueAt(i);
} else { } else {
defaultModel = comboBox.valueAt(0); defaultModel = comboBox.count ? comboBox.valueAt(0) : "";
} }
if (defaultModel !== "") { if (defaultModel !== "") {
defaultModelName = ModelList.modelInfo(defaultModel).name; defaultModelName = ModelList.modelInfo(defaultModel).name;
@ -790,9 +803,9 @@ Rectangle {
Layout.leftMargin: 50 Layout.leftMargin: 50
Layout.rightMargin: 50 Layout.rightMargin: 50
Layout.alignment: Qt.AlignHCenter Layout.alignment: Qt.AlignHCenter
spacing: 25 spacing: 10
model: chatModel model: chatModel
cacheBuffer: Math.max(0, listView.contentHeight) cacheBuffer: 2147483647
ScrollBar.vertical: ScrollBar { ScrollBar.vertical: ScrollBar {
policy: ScrollBar.AsNeeded policy: ScrollBar.AsNeeded
@ -804,6 +817,12 @@ Rectangle {
delegate: ChatItemView { delegate: ChatItemView {
width: listView.contentItem.width - 15 width: listView.contentItem.width - 15
inputBoxText: textInput.text
onSetInputBoxText: text => {
textInput.text = text;
textInput.forceActiveFocus();
textInput.cursorPosition = text.length;
}
} }
function scrollToEnd() { function scrollToEnd() {
@ -832,11 +851,9 @@ Rectangle {
clip: true clip: true
z: 400 z: 400
property bool isHovered: { property bool isHovered: (
return conversationTrayButton.isHovered || conversationTrayButton.isHovered || resetContextButton.hovered || copyChatButton.hovered
resetContextButton.hovered || copyChatButton.hovered || )
regenerateButton.hovered
}
state: conversationTrayContent.isHovered ? "expanded" : "collapsed" state: conversationTrayContent.isHovered ? "expanded" : "collapsed"
states: [ states: [
@ -892,11 +909,7 @@ Rectangle {
source: "qrc:/gpt4all/icons/recycle.svg" source: "qrc:/gpt4all/icons/recycle.svg"
imageWidth: 20 imageWidth: 20
imageHeight: 20 imageHeight: 20
onClicked: { onClicked: resetContextDialog.open()
Network.trackChatEvent("reset_context", { "length": chatModel.count })
currentChat.reset();
currentChat.processSystemPrompt();
}
ToolTip.visible: resetContextButton.hovered ToolTip.visible: resetContextButton.hovered
ToolTip.text: qsTr("Erase and reset chat session") ToolTip.text: qsTr("Erase and reset chat session")
} }
@ -921,34 +934,6 @@ Rectangle {
ToolTip.visible: copyChatButton.hovered ToolTip.visible: copyChatButton.hovered
ToolTip.text: qsTr("Copy chat session to clipboard") ToolTip.text: qsTr("Copy chat session to clipboard")
} }
MyToolButton {
id: regenerateButton
Layout.preferredWidth: 40
Layout.preferredHeight: 40
source: "qrc:/gpt4all/icons/regenerate.svg"
imageWidth: 20
imageHeight: 20
visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
onClicked: {
if (chatModel.count < 2)
return
var promptIndex = chatModel.count - 2
var promptElement = chatModel.get(promptIndex)
var responseIndex = chatModel.count - 1
var responseElement = chatModel.get(responseIndex)
if (promptElement.name !== "Prompt: " || responseElement.name !== "Response: ")
return
currentChat.regenerateResponse()
chatModel.updateCurrentResponse(responseIndex, true)
chatModel.updateStopped(responseIndex, false)
chatModel.updateThumbsUpState(responseIndex, false)
chatModel.updateThumbsDownState(responseIndex, false)
chatModel.updateNewResponse(responseIndex, "")
currentChat.prompt(promptElement.promptPlusAttachments)
}
ToolTip.visible: regenerateButton.hovered
ToolTip.text: qsTr("Redo last chat response")
}
} }
} }
@ -1026,13 +1011,15 @@ Rectangle {
anchors.leftMargin: 30 anchors.leftMargin: 30
horizontalAlignment: Qt.AlignRight horizontalAlignment: Qt.AlignRight
verticalAlignment: Qt.AlignVCenter verticalAlignment: Qt.AlignVCenter
color: theme.mutedTextColor color: textInputView.error !== null ? theme.textErrorColor : theme.mutedTextColor
visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" || textInputView.error !== null
elide: Text.ElideRight elide: Text.ElideRight
wrapMode: Text.WordWrap wrapMode: Text.WordWrap
text: { text: {
if (externalHoveredLink !== "") if (externalHoveredLink !== "")
return externalHoveredLink return externalHoveredLink
if (textInputView.error !== null)
return textInputView.error;
const segments = [currentChat.tokenSpeed]; const segments = [currentChat.tokenSpeed];
const device = currentChat.device; const device = currentChat.device;
@ -1050,6 +1037,7 @@ Rectangle {
} }
font.pixelSize: theme.fontSizeSmaller font.pixelSize: theme.fontSizeSmaller
font.bold: true font.bold: true
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
} }
RectangularGlow { RectangularGlow {
@ -1079,8 +1067,8 @@ Rectangle {
Rectangle { Rectangle {
id: textInputView id: textInputView
color: theme.controlBackground color: theme.controlBackground
border.width: 1 border.width: error === null ? 1 : 2
border.color: theme.controlBorder border.color: error === null ? theme.controlBorder : theme.textErrorColor
radius: 10 radius: 10
anchors.left: parent.left anchors.left: parent.left
anchors.right: parent.right anchors.right: parent.right
@ -1091,6 +1079,41 @@ Rectangle {
height: textInputViewLayout.implicitHeight height: textInputViewLayout.implicitHeight
visible: !currentChat.isServer && ModelList.selectableModels.count !== 0 visible: !currentChat.isServer && ModelList.selectableModels.count !== 0
property var error: null
function checkError() {
const info = currentModelInfo;
if (info === null || !info.id) {
error = null;
} else if (info.chatTemplate.isLegacy) {
error = qsTr("Legacy prompt template needs to be " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
"</a> in Settings.");
} else if (!info.chatTemplate.isSet) {
error = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> configured.");
} else if (/^\s*$/.test(info.chatTemplate.value)) {
error = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> cannot be blank.");
} else if (info.systemMessage.isLegacy) {
error = qsTr("Legacy system prompt needs to be " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
"</a> in Settings.");
} else
error = null;
}
Component.onCompleted: checkError()
Connections {
target: window
function onCurrentModelIdChanged() { textInputView.checkError(); }
}
Connections {
target: MySettings
function onChatTemplateChanged(info)
{ if (info.id === window.currentModelId) textInputView.checkError(); }
function onSystemMessageChanged(info)
{ if (info.id === window.currentModelId) textInputView.checkError(); }
}
MouseArea { MouseArea {
id: textInputViewMouseArea id: textInputViewMouseArea
anchors.fill: parent anchors.fill: parent
@ -1214,16 +1237,16 @@ Rectangle {
Accessible.role: Accessible.EditableText Accessible.role: Accessible.EditableText
Accessible.name: placeholderText Accessible.name: placeholderText
Accessible.description: qsTr("Send messages/prompts to the model") Accessible.description: qsTr("Send messages/prompts to the model")
Keys.onReturnPressed: (event)=> { Keys.onReturnPressed: event => {
if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) {
event.accepted = false; event.accepted = false;
else { } else if (!chatModel.hasError && textInputView.error === null) {
editingFinished(); editingFinished();
sendMessage() sendMessage();
} }
} }
function sendMessage() { function sendMessage() {
if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress || currentChat.restoringFromText) if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress)
return return
currentChat.stopGenerating() currentChat.stopGenerating()
@ -1338,6 +1361,7 @@ Rectangle {
imageWidth: theme.fontSizeLargest imageWidth: theme.fontSizeLargest
imageHeight: theme.fontSizeLargest imageHeight: theme.fontSizeLargest
visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0 visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0
enabled: !chatModel.hasError && textInputView.error === null
source: "qrc:/gpt4all/icons/send_message.svg" source: "qrc:/gpt4all/icons/send_message.svg"
Accessible.name: qsTr("Send message") Accessible.name: qsTr("Send message")
Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model") Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model")
View File
@ -0,0 +1,59 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
MyDialog {
id: confirmationDialog
anchors.centerIn: parent
modal: true
padding: 20
property alias dialogTitle: titleText.text
property alias description: descriptionText.text
Theme { id: theme }
contentItem: ColumnLayout {
Text {
id: titleText
Layout.alignment: Qt.AlignHCenter
textFormat: Text.StyledText
color: theme.textColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
}
Text {
id: descriptionText
Layout.alignment: Qt.AlignHCenter
textFormat: Text.StyledText
color: theme.textColor
font.pixelSize: theme.fontSizeMedium
}
}
footer: DialogButtonBox {
id: dialogBox
padding: 20
alignment: Qt.AlignRight
spacing: 10
MySettingsButton {
text: qsTr("OK")
textColor: theme.mediumButtonText
backgroundColor: theme.mediumButtonBackground
backgroundColorHovered: theme.mediumButtonBackgroundHovered
DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
}
MySettingsButton {
text: qsTr("Cancel")
DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
}
background: Rectangle {
color: "transparent"
}
Keys.onEnterPressed: confirmationDialog.accept()
Keys.onReturnPressed: confirmationDialog.accept()
}
Component.onCompleted: dialogBox.forceActiveFocus()
}
View File
@ -10,7 +10,7 @@ import mysettings
import network import network
MySettingsTab { MySettingsTab {
onRestoreDefaultsClicked: { onRestoreDefaults: {
MySettings.restoreLocalDocsDefaults(); MySettings.restoreLocalDocsDefaults();
} }
View File
@ -8,10 +8,34 @@ import mysettings
import chatlistmodel import chatlistmodel
MySettingsTab { MySettingsTab {
onRestoreDefaultsClicked: { onRestoreDefaults: {
MySettings.restoreModelDefaults(root.currentModelInfo); MySettings.restoreModelDefaults(root.currentModelInfo);
} }
title: qsTr("Model") title: qsTr("Model")
ConfirmationDialog {
id: resetSystemMessageDialog
property var index: null
property bool resetClears: false
dialogTitle: qsTr("%1 system message?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
description: qsTr("The system message will be %1.").arg(resetClears ? qsTr("removed") : qsTr("reset to the default"))
onAccepted: MySettings.resetModelSystemMessage(ModelList.modelInfo(index))
function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
}
ConfirmationDialog {
id: resetChatTemplateDialog
property bool resetClears: false
property var index: null
dialogTitle: qsTr("%1 chat template?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
description: qsTr("The chat template will be %1.").arg(resetClears ? qsTr("erased") : qsTr("reset to the default"))
onAccepted: {
MySettings.resetModelChatTemplate(ModelList.modelInfo(index));
templateTextArea.resetText();
}
function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
}
contentItem: GridLayout { contentItem: GridLayout {
id: root id: root
columns: 3 columns: 3
@ -35,6 +59,7 @@ MySettingsTab {
RowLayout { RowLayout {
Layout.fillWidth: true Layout.fillWidth: true
Layout.maximumWidth: parent.width
Layout.row: 2 Layout.row: 2
Layout.column: 0 Layout.column: 0
Layout.columnSpan: 2 Layout.columnSpan: 2
@ -153,69 +178,154 @@ MySettingsTab {
Layout.fillWidth: true Layout.fillWidth: true
} }
MySettingsLabel { RowLayout {
visible: !root.currentModelInfo.isOnline
text: qsTr("System Prompt")
helpText: qsTr("Prefixed at the beginning of every conversation. Must contain the appropriate framing tokens.")
Layout.row: 7 Layout.row: 7
Layout.column: 0 Layout.columnSpan: 2
Layout.topMargin: 15 Layout.topMargin: 15
Layout.fillWidth: true
Layout.maximumWidth: parent.width
spacing: 10
MySettingsLabel {
id: systemMessageLabel
text: qsTr("System Message")
helpText: qsTr("A message to set the context or guide the behavior of the model. Leave blank for " +
"none. NOTE: Since GPT4All 3.5, this should not contain control tokens.")
onReset: () => resetSystemMessageDialog.show(root.currentModelId, resetClears)
function updateResetButton() {
const info = root.currentModelInfo;
// NOTE: checks if the *override* is set, regardless of whether there is a default
canReset = !!info.id && MySettings.isModelSystemMessageSet(info);
resetClears = !info.defaultSystemMessage;
}
Component.onCompleted: updateResetButton()
Connections {
target: root
function onCurrentModelIdChanged() { systemMessageLabel.updateResetButton(); }
}
Connections {
target: MySettings
function onSystemMessageChanged(info)
{ if (info.id === root.currentModelId) systemMessageLabel.updateResetButton(); }
}
}
Label {
id: systemMessageLabelHelp
visible: systemMessageArea.errState !== "ok"
Layout.alignment: Qt.AlignBottom
Layout.fillWidth: true
Layout.rightMargin: 5
Layout.maximumHeight: systemMessageLabel.height
text: qsTr("System message is not " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">plain text</a>.")
color: systemMessageArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
wrapMode: Text.Wrap
elide: Text.ElideRight
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
}
} }
Rectangle { Rectangle {
id: systemPrompt id: systemMessage
visible: !root.currentModelInfo.isOnline
Layout.row: 8 Layout.row: 8
Layout.column: 0 Layout.column: 0
Layout.columnSpan: 2 Layout.columnSpan: 2
Layout.fillWidth: true Layout.fillWidth: true
color: "transparent" color: "transparent"
Layout.minimumHeight: Math.max(100, systemPromptArea.contentHeight + 20) Layout.minimumHeight: Math.max(100, systemMessageArea.contentHeight + 20)
MyTextArea { MyTextArea {
id: systemPromptArea id: systemMessageArea
anchors.fill: parent anchors.fill: parent
text: root.currentModelInfo.systemPrompt property bool isBeingReset: false
function resetText() {
const info = root.currentModelInfo;
isBeingReset = true;
text = (info.id ? info.systemMessage.value : null) ?? "";
isBeingReset = false;
}
Component.onCompleted: resetText()
Connections { Connections {
target: MySettings target: MySettings
function onSystemPromptChanged() { function onSystemMessageChanged(info)
systemPromptArea.text = root.currentModelInfo.systemPrompt; { if (info.id === root.currentModelId) systemMessageArea.resetText(); }
}
} }
Connections { Connections {
target: root target: root
function onCurrentModelInfoChanged() { function onCurrentModelIdChanged() { systemMessageArea.resetText(); }
systemPromptArea.text = root.currentModelInfo.systemPrompt;
}
} }
// strict validation, because setModelSystemMessage clears isLegacy
readonly property var reLegacyCheck: (
/(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>/m
)
onTextChanged: { onTextChanged: {
MySettings.setModelSystemPrompt(root.currentModelInfo, text) const info = root.currentModelInfo;
if (!info.id) {
errState = "ok";
} else if (info.systemMessage.isLegacy && (isBeingReset || reLegacyCheck.test(text))) {
errState = "error";
} else
errState = reLegacyCheck.test(text) ? "warning" : "ok";
if (info.id && errState !== "error" && !isBeingReset)
MySettings.setModelSystemMessage(info, text);
systemMessageLabel.updateResetButton();
} }
Accessible.role: Accessible.EditableText Accessible.role: Accessible.EditableText
Accessible.name: systemMessageLabel.text
Accessible.description: systemMessageLabelHelp.text
} }
} }
RowLayout { RowLayout {
Layout.row: 9 Layout.row: 9
Layout.column: 0
Layout.columnSpan: 2 Layout.columnSpan: 2
Layout.topMargin: 15 Layout.topMargin: 15
Layout.fillWidth: true
Layout.maximumWidth: parent.width
spacing: 10 spacing: 10
MySettingsLabel { MySettingsLabel {
id: promptTemplateLabel id: chatTemplateLabel
text: qsTr("Prompt Template") text: qsTr("Chat Template")
helpText: qsTr("The template that wraps every prompt.") helpText: qsTr("This Jinja template turns the chat into input for the model.")
onReset: () => resetChatTemplateDialog.show(root.currentModelId, resetClears)
function updateResetButton() {
const info = root.currentModelInfo;
canReset = !!info.id && (
MySettings.isModelChatTemplateSet(info)
|| templateTextArea.text !== (info.chatTemplate.value ?? "")
);
resetClears = !info.defaultChatTemplate;
}
Component.onCompleted: updateResetButton()
Connections {
target: root
function onCurrentModelIdChanged() { chatTemplateLabel.updateResetButton(); }
}
Connections {
target: MySettings
function onChatTemplateChanged(info)
{ if (info.id === root.currentModelId) chatTemplateLabel.updateResetButton(); }
}
} }
MySettingsLabel { Label {
id: promptTemplateLabelHelp id: chatTemplateLabelHelp
text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.") visible: templateTextArea.errState !== "ok"
color: theme.textErrorColor Layout.alignment: Qt.AlignBottom
visible: templateTextArea.text.indexOf("%1") === -1 Layout.fillWidth: true
wrapMode: TextArea.Wrap Layout.rightMargin: 5
Layout.maximumHeight: chatTemplateLabel.height
text: templateTextArea.errMsg
color: templateTextArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
wrapMode: Text.Wrap
elide: Text.ElideRight
onLinkActivated: function(link) { Qt.openUrlExternally(link) }
} }
} }
Rectangle { Rectangle {
id: promptTemplate id: chatTemplate
Layout.row: 10 Layout.row: 10
Layout.column: 0 Layout.column: 0
Layout.columnSpan: 2 Layout.columnSpan: 2
@ -226,27 +336,71 @@ MySettingsTab {
MyTextArea { MyTextArea {
id: templateTextArea id: templateTextArea
anchors.fill: parent anchors.fill: parent
text: root.currentModelInfo.promptTemplate font: fixedFont
property bool isBeingReset: false
property var errMsg: null
function resetText() {
const info = root.currentModelInfo;
isBeingReset = true;
text = (info.id ? info.chatTemplate.value : null) ?? "";
isBeingReset = false;
}
Component.onCompleted: resetText()
Connections { Connections {
target: MySettings target: MySettings
function onPromptTemplateChanged() { function onChatTemplateChanged() { templateTextArea.resetText(); }
templateTextArea.text = root.currentModelInfo.promptTemplate;
}
} }
Connections { Connections {
target: root target: root
function onCurrentModelInfoChanged() { function onCurrentModelIdChanged() { templateTextArea.resetText(); }
templateTextArea.text = root.currentModelInfo.promptTemplate; }
} function legacyCheck() {
return /%[12]\b/.test(text) || !/\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}/.test(text.replace(/\n/g, ''))
|| !/\bcontent\b/.test(text);
} }
onTextChanged: { onTextChanged: {
if (templateTextArea.text.indexOf("%1") !== -1) { const info = root.currentModelInfo;
MySettings.setModelPromptTemplate(root.currentModelInfo, text) let jinjaError;
if (!info.id) {
errMsg = null;
errState = "ok";
} else if (info.chatTemplate.isLegacy && (isBeingReset || legacyCheck())) {
errMsg = null;
errState = "error";
} else if (text === "" && !info.chatTemplate.isSet) {
errMsg = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> configured.");
errState = "error";
} else if (/^\s*$/.test(text)) {
errMsg = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"chat template</a> cannot be blank.");
errState = "error";
} else if ((jinjaError = MySettings.checkJinjaTemplateError(text)) !== null) {
errMsg = qsTr("<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">Syntax" +
" error</a>: %1").arg(jinjaError);
errState = "error";
} else if (legacyCheck()) {
errMsg = qsTr("Chat template is not in " +
"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
"Jinja format</a>.")
errState = "warning";
} else {
errState = "ok";
}
if (info.id && errState !== "error" && !isBeingReset)
MySettings.setModelChatTemplate(info, text);
chatTemplateLabel.updateResetButton();
}
Keys.onPressed: event => {
if (event.key === Qt.Key_Tab) {
const a = templateTextArea;
event.accepted = true; // suppress tab
a.insert(a.cursorPosition, '    '); // four spaces
} }
} }
Accessible.role: Accessible.EditableText Accessible.role: Accessible.EditableText
Accessible.name: promptTemplateLabel.text Accessible.name: chatTemplateLabel.text
Accessible.description: promptTemplateLabelHelp.text Accessible.description: chatTemplateLabelHelp.text
} }
} }
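For context on the validation above: the template editor calls MySettings.checkJinjaTemplateError(text) and treats any non-null return as a syntax error to interpolate into the "Syntax error: %1" message. The C++ side of that helper is not shown in this hunk; the sketch below is one plausible shape for it on top of Jinja2Cpp's Template::Load(), and the function body and error formatting are assumptions for illustration.

// Illustrative sketch in the spirit of MySettings.checkJinjaTemplateError();
// the real implementation is not part of this hunk. Jinja2Cpp reports parse
// failures from Template::Load(), which is enough to build the user-facing
// syntax-error string the settings page displays.
#include <jinja2cpp/template.h>

#include <optional>
#include <sstream>
#include <string>

// Returns a syntax-error description, or std::nullopt if the template parses cleanly.
std::optional<std::string> checkJinjaTemplateError(const std::string &candidate)
{
    jinja2::Template tpl;
    auto result = tpl.Load(candidate);
    if (!result) {
        std::ostringstream message;
        message << result.error(); // Jinja2Cpp streams a human-readable description
        return message.str();
    }
    return std::nullopt;
}

The QML only needs the null/non-null distinction plus a printable message, so a check of this shape is enough to drive both the error highlight and the help label above.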
View File
@ -17,6 +17,7 @@ Button {
property color borderColor: "transparent" property color borderColor: "transparent"
property real fontPixelSize: theme.fontSizeLarge property real fontPixelSize: theme.fontSizeLarge
property string toolTip property string toolTip
property alias backgroundRadius: background.radius
contentItem: Text { contentItem: Text {
text: myButton.text text: myButton.text
@ -28,6 +29,7 @@ Button {
Accessible.name: text Accessible.name: text
} }
background: Rectangle { background: Rectangle {
id: background
radius: 10 radius: 10
border.width: borderWidth border.width: borderWidth
border.color: borderColor border.color: borderColor
View File
@ -17,13 +17,42 @@ ColumnLayout {
property alias color: mainTextLabel.color property alias color: mainTextLabel.color
property alias linkColor: mainTextLabel.linkColor property alias linkColor: mainTextLabel.linkColor
Label { property var onReset: null
id: mainTextLabel property alias canReset: resetButton.enabled
color: theme.settingsTitleTextColor property bool resetClears: false
font.pixelSize: theme.fontSizeLarger
font.bold: true Item {
onLinkActivated: function(link) { anchors.margins: 5
root.linkActivated(link); width: childrenRect.width
height: mainTextLabel.contentHeight
Label {
id: mainTextLabel
anchors.left: parent.left
anchors.top: parent.top
anchors.bottom: parent.bottom
color: theme.settingsTitleTextColor
font.pixelSize: theme.fontSizeLarger
font.bold: true
verticalAlignment: Text.AlignVCenter
onLinkActivated: function(link) {
root.linkActivated(link);
}
}
MySettingsButton {
id: resetButton
anchors.baseline: mainTextLabel.baseline
anchors.left: mainTextLabel.right
height: mainTextLabel.contentHeight
anchors.leftMargin: 10
padding: 2
leftPadding: 10
rightPadding: 10
backgroundRadius: 5
text: resetClears ? qsTr("Clear") : qsTr("Reset")
visible: root.onReset !== null
onClicked: root.onReset()
} }
} }
Label { Label {
View File
@ -9,7 +9,7 @@ Item {
property string title: "" property string title: ""
property Item contentItem: null property Item contentItem: null
property bool showRestoreDefaultsButton: true property bool showRestoreDefaultsButton: true
signal restoreDefaultsClicked signal restoreDefaults
onContentItemChanged: function() { onContentItemChanged: function() {
if (contentItem) { if (contentItem) {
@ -19,6 +19,13 @@ Item {
} }
} }
ConfirmationDialog {
id: restoreDefaultsDialog
dialogTitle: qsTr("Restore defaults?")
description: qsTr("This page of settings will be reset to the defaults.")
onAccepted: root.restoreDefaults()
}
ScrollView { ScrollView {
id: scrollView id: scrollView
width: parent.width width: parent.width
@ -47,6 +54,7 @@ Item {
Column { Column {
id: contentInner id: contentInner
Layout.fillWidth: true Layout.fillWidth: true
Layout.maximumWidth: parent.width
} }
Item { Item {
@ -63,9 +71,7 @@ Item {
Accessible.role: Accessible.Button Accessible.role: Accessible.Button
Accessible.name: text Accessible.name: text
Accessible.description: qsTr("Restores settings dialog to a default state") Accessible.description: qsTr("Restores settings dialog to a default state")
onClicked: { onClicked: restoreDefaultsDialog.open()
root.restoreDefaultsClicked();
}
} }
} }
} }
View File
@ -5,18 +5,27 @@ import QtQuick.Controls.Basic
TextArea { TextArea {
id: myTextArea id: myTextArea
property string errState: "ok" // one of "ok", "error", "warning"
color: enabled ? theme.textColor : theme.mutedTextColor color: enabled ? theme.textColor : theme.mutedTextColor
placeholderTextColor: theme.mutedTextColor placeholderTextColor: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
background: Rectangle { background: Rectangle {
implicitWidth: 150 implicitWidth: 150
color: theme.controlBackground color: theme.controlBackground
border.width: 1 border.width: errState === "ok" ? 1 : 2
border.color: theme.controlBorder border.color: {
switch (errState) {
case "ok": return theme.controlBorder;
case "warning": return theme.textWarningColor;
case "error": return theme.textErrorColor;
}
}
radius: 10 radius: 10
} }
padding: 10 padding: 10
wrapMode: TextArea.Wrap wrapMode: TextArea.Wrap
ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval
} }
View File
@ -16,6 +16,7 @@ Button {
property alias fillMode: image.fillMode property alias fillMode: image.fillMode
property alias imageWidth: image.sourceSize.width property alias imageWidth: image.sourceSize.width
property alias imageHeight: image.sourceSize.height property alias imageHeight: image.sourceSize.height
property alias bgTransform: background.transform
contentItem: Text { contentItem: Text {
text: myButton.text text: myButton.text
horizontalAlignment: Text.AlignHCenter horizontalAlignment: Text.AlignHCenter
@ -26,6 +27,7 @@ Button {
} }
background: Item { background: Item {
id: background
anchors.fill: parent anchors.fill: parent
Rectangle { Rectangle {
anchors.fill: parent anchors.fill: parent
View File
@ -1,46 +0,0 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
import llm
import mysettings
MyDialog {
id: switchModelDialog
anchors.centerIn: parent
modal: true
padding: 20
property int index: -1
Theme {
id: theme
}
contentItem: Text {
textFormat: Text.StyledText
text: qsTr("<b>Warning:</b> changing the model will erase the current conversation. Do you wish to continue?")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
}
footer: DialogButtonBox {
id: dialogBox
padding: 20
alignment: Qt.AlignRight
spacing: 10
MySettingsButton {
text: qsTr("Continue")
Accessible.description: qsTr("Continue with model loading")
DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
}
MySettingsButton {
text: qsTr("Cancel")
Accessible.description: qsTr("Cancel")
DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
}
background: Rectangle {
color: "transparent"
}
}
}
View File
@ -64,6 +64,9 @@ QtObject {
property color green800: Qt.hsla(123/360, 0.17, 0.24) property color green800: Qt.hsla(123/360, 0.17, 0.24)
property color green900: Qt.hsla(124/360, 0.17, 0.20) property color green900: Qt.hsla(124/360, 0.17, 0.20)
property color green950: Qt.hsla(125/360, 0.22, 0.10) property color green950: Qt.hsla(125/360, 0.22, 0.10)
property color green300_sat: Qt.hsla(122/360, 0.24, 0.73)
property color green400_sat: Qt.hsla(122/360, 0.23, 0.58)
property color green450_sat: Qt.hsla(122/360, 0.23, 0.52)
// yellow // yellow
property color yellow0: Qt.hsla(47/360, 0.90, 0.99) property color yellow0: Qt.hsla(47/360, 0.90, 0.99)
@ -99,6 +102,7 @@ QtObject {
property color purple200: Qt.hsla(279/360, 1.0, 0.91) property color purple200: Qt.hsla(279/360, 1.0, 0.91)
property color purple300: Qt.hsla(279/360, 1.0, 0.84) property color purple300: Qt.hsla(279/360, 1.0, 0.84)
property color purple400: Qt.hsla(279/360, 1.0, 0.73) property color purple400: Qt.hsla(279/360, 1.0, 0.73)
property color purple450: Qt.hsla(279/360, 1.0, 0.68)
property color purple500: Qt.hsla(279/360, 1.0, 0.63) property color purple500: Qt.hsla(279/360, 1.0, 0.63)
property color purple600: Qt.hsla(279/360, 1.0, 0.53) property color purple600: Qt.hsla(279/360, 1.0, 0.53)
property color purple700: Qt.hsla(279/360, 1.0, 0.47) property color purple700: Qt.hsla(279/360, 1.0, 0.47)
@ -408,6 +412,39 @@ QtObject {
} }
} }
property color mediumButtonBackground: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return purple400
case MySettingsEnums.ChatTheme.Dark:
return green400_sat
default:
return green400_sat
}
}
property color mediumButtonBackgroundHovered: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return purple450
case MySettingsEnums.ChatTheme.Dark:
return green450_sat
default:
return green300_sat
}
}
property color mediumButtonText: {
switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark:
return textColor
case MySettingsEnums.ChatTheme.Dark:
return textColor
default:
return white
}
}
property color darkButtonText: { property color darkButtonText: {
switch (MySettings.chatTheme) { switch (MySettings.chatTheme) {
case MySettingsEnums.ChatTheme.LegacyDark: case MySettingsEnums.ChatTheme.LegacyDark:
@ -922,16 +959,8 @@ QtObject {
} }
} }
property color textErrorColor: { readonly property color textErrorColor: red400
switch (MySettings.chatTheme) { readonly property color textWarningColor: yellow400
case MySettingsEnums.ChatTheme.LegacyDark:
return red400
case MySettingsEnums.ChatTheme.Dark:
return red400
default:
return red400
}
}
property color settingsTitleTextColor: { property color settingsTitleTextColor: {
switch (MySettings.chatTheme) { switch (MySettings.chatTheme) {
View File
@ -1,7 +1,6 @@
#include "chat.h" #include "chat.h"
#include "chatlistmodel.h" #include "chatlistmodel.h"
#include "mysettings.h"
#include "network.h" #include "network.h"
#include "server.h" #include "server.h"
@ -11,7 +10,6 @@
#include <QLatin1String> #include <QLatin1String>
#include <QMap> #include <QMap>
#include <QString> #include <QString>
#include <QStringList>
#include <QVariant> #include <QVariant>
#include <Qt> #include <Qt>
#include <QtLogging> #include <QtLogging>
@ -56,18 +54,18 @@ void Chat::connectLLM()
// Should be in different threads // Should be in different threads
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
@ -75,11 +73,10 @@ void Chat::connectLLM()
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection); connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
connect(ModelList::globalInstance(), &ModelList::modelInfoChanged, this, &Chat::handleModelInfoChanged);
} }
void Chat::reset() void Chat::reset()
@ -87,28 +84,17 @@ void Chat::reset()
stopGenerating(); stopGenerating();
// Erase our current on disk representation as we're completely resetting the chat along with id // Erase our current on disk representation as we're completely resetting the chat along with id
ChatListModel::globalInstance()->removeChatFile(this); ChatListModel::globalInstance()->removeChatFile(this);
emit resetContextRequested();
m_id = Network::globalInstance()->generateUniqueId(); m_id = Network::globalInstance()->generateUniqueId();
emit idChanged(m_id); emit idChanged(m_id);
// NOTE: We deliberately do no reset the name or creation date to indicate that this was originally
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
// further down in the list. This might surprise the user. In the future, we might get rid of
// the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
// we effectively do a reset context. We *have* to do this right now when switching between different
// types of models. The only way to get rid of that would be a very long recalculate where we rebuild
// the context if we switch between different types of models. Probably the right way to fix this
// is to allow switching models but throwing up a dialog warning users if we switch between types
// of models that a long recalculation will ensue.
// of models that a long recalculation will ensue.
m_chatModel->clear(); m_chatModel->clear();
m_needsSave = true; m_needsSave = true;
} }
void Chat::processSystemPrompt()
{
emit processSystemPromptRequested();
}
void Chat::resetResponseState() void Chat::resetResponseState()
{ {
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@ -160,25 +146,30 @@ void Chat::newPromptResponsePair(const QString &prompt, const QList<QUrl> &attac
if (!attachedContexts.isEmpty()) if (!attachedContexts.isEmpty())
promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt; promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt;
newPromptResponsePairInternal(prompt, attachments); resetResponseState();
emit resetResponseRequested(); qsizetype prevMsgIndex = m_chatModel->count() - 1;
if (prevMsgIndex >= 0)
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
m_chatModel->appendPrompt(prompt, attachments);
m_chatModel->appendResponse(prevMsgIndex + 1);
this->prompt(promptPlusAttached); emit promptRequested(m_collections);
m_needsSave = true;
} }
void Chat::prompt(const QString &prompt) void Chat::regenerateResponse(int index)
{ {
resetResponseState(); resetResponseState();
emit promptRequested(m_collections, prompt); emit regenerateResponseRequested(index);
m_needsSave = true; m_needsSave = true;
} }
void Chat::regenerateResponse() QVariant Chat::popPrompt(int index)
{ {
const int index = m_chatModel->count() - 1; auto content = m_llmodel->popPrompt(index);
m_chatModel->updateSources(index, QList<ResultInfo>());
emit regenerateResponseRequested();
m_needsSave = true; m_needsSave = true;
if (content) return *content;
return QVariant::fromValue(nullptr);
} }
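For reference, a minimal sketch of the value convention popPrompt() follows here, assuming Qt headers are available; promptToVariant is an illustrative name, not part of the patch:

    #include <QString>
    #include <QVariant>
    #include <optional>

    // A present prompt becomes a QString QVariant; absence becomes a null-pointer
    // QVariant, which QML can compare against null (an empty QVariant would read
    // as undefined instead).
    QVariant promptToVariant(const std::optional<QString> &content)
    {
        if (content)
            return *content;                  // implicit QVariant(QString)
        return QVariant::fromValue(nullptr);  // same fallback as above
    }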
void Chat::stopGenerating() void Chat::stopGenerating()
@ -202,6 +193,14 @@ void Chat::handleResponseChanged(const QString &response)
m_chatModel->updateValue(index, response); m_chatModel->updateValue(index, response);
} }
void Chat::handleResponseFailed(const QString &error)
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, error);
m_chatModel->setError();
responseStopped(0);
}
void Chat::handleModelLoadingPercentageChanged(float loadingPercentage) void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
{ {
if (m_shouldDeleteLater) if (m_shouldDeleteLater)
@ -272,25 +271,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
emit modelChangeRequested(modelInfo); emit modelChangeRequested(modelInfo);
} }
// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
void Chat::serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments)
{
newPromptResponsePairInternal(prompt, attachments);
}
void Chat::newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments)
{
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendPrompt("Prompt: ", prompt, attachments);
m_chatModel->appendResponse("Response: ");
}
bool Chat::restoringFromText() const
{
return m_llmodel->restoringFromText();
}
void Chat::unloadAndDeleteLater() void Chat::unloadAndDeleteLater()
{ {
if (!isModelLoaded()) { if (!isModelLoaded()) {
@ -356,12 +336,6 @@ void Chat::generatedQuestionFinished(const QString &question)
m_needsSave = true; m_needsSave = true;
} }
void Chat::handleRestoringFromText()
{
Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
emit restoringFromTextChanged();
}
void Chat::handleModelLoadingError(const QString &error) void Chat::handleModelLoadingError(const QString &error)
{ {
if (!error.isEmpty()) { if (!error.isEmpty()) {
@ -396,12 +370,19 @@ QString Chat::fallbackReason() const
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results) void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{ {
m_databaseResults = results; m_databaseResults = results;
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, m_databaseResults);
m_needsSave = true; m_needsSave = true;
} }
// we need to notify listeners of the modelInfo property when its properties are updated,
// since it's a gadget and can't do that on its own
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo) void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
{
if (!m_modelInfo.id().isNull() && modelInfo.id() == m_modelInfo.id())
emit modelInfoChanged();
}
// react if a new model is loaded
void Chat::handleModelChanged(const ModelInfo &modelInfo)
{ {
if (m_modelInfo == modelInfo) if (m_modelInfo == modelInfo)
return; return;
@ -430,10 +411,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
if (version >= 3) if (version >= 3)
stream << m_collections; stream << m_collections;
const bool serializeKV = MySettings::globalInstance()->saveChatsContext(); if (!m_llmodel->serialize(stream, version))
if (version >= 6)
stream << serializeKV;
if (!m_llmodel->serialize(stream, version, serializeKV))
return false; return false;
if (!m_chatModel->serialize(stream, version)) if (!m_chatModel->serialize(stream, version))
return false; return false;
@ -462,19 +440,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
if (!m_modelInfo.id().isEmpty()) if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged(); emit modelInfoChanged();
bool discardKV = m_modelInfo.id().isEmpty();
if (version >= 3) { if (version >= 3) {
stream >> m_collections; stream >> m_collections;
emit collectionListChanged(m_collections); emit collectionListChanged(m_collections);
} }
bool deserializeKV = true;
if (version >= 6)
stream >> deserializeKV;
m_llmodel->setModelInfo(m_modelInfo); m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV)) if (!m_llmodel->deserialize(stream, version))
return false; return false;
if (!m_chatModel->deserialize(stream, version)) if (!m_chatModel->deserialize(stream, version))
return false; return false;

View File

@ -12,6 +12,8 @@
#include <QObject> #include <QObject>
#include <QQmlEngine> #include <QQmlEngine>
#include <QString> #include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QStringView>
#include <QtGlobal> #include <QtGlobal>
class QDataStream; class QDataStream;
@ -27,7 +29,6 @@ class Chat : public QObject
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged) Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged) Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged) Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged) Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged) Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@ -77,13 +78,12 @@ public:
bool isNewChat() const { return m_name == tr("New Chat") && !m_chatModel->count(); } bool isNewChat() const { return m_name == tr("New Chat") && !m_chatModel->count(); }
Q_INVOKABLE void reset(); Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt();
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; } bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; } bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; } float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls = {}); Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls = {});
Q_INVOKABLE void prompt(const QString &prompt); Q_INVOKABLE void regenerateResponse(int index);
Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE QVariant popPrompt(int index);
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
QList<ResultInfo> databaseResults() const { return m_databaseResults; } QList<ResultInfo> databaseResults() const { return m_databaseResults; }
@ -92,7 +92,6 @@ public:
ResponseState responseState() const; ResponseState responseState() const;
ModelInfo modelInfo() const; ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &modelInfo); void setModelInfo(const ModelInfo &modelInfo);
bool restoringFromText() const;
Q_INVOKABLE void unloadModel(); Q_INVOKABLE void unloadModel();
Q_INVOKABLE void reloadModel(); Q_INVOKABLE void reloadModel();
@ -113,7 +112,6 @@ public:
Q_INVOKABLE bool hasCollection(const QString &collection) const; Q_INVOKABLE bool hasCollection(const QString &collection) const;
Q_INVOKABLE void addCollection(const QString &collection); Q_INVOKABLE void addCollection(const QString &collection);
Q_INVOKABLE void removeCollection(const QString &collection); Q_INVOKABLE void removeCollection(const QString &collection);
void resetResponseState();
QString modelLoadingError() const { return m_modelLoadingError; } QString modelLoadingError() const { return m_modelLoadingError; }
@ -131,7 +129,7 @@ public:
void setNeedsSave(bool n) { m_needsSave = n; } void setNeedsSave(bool n) { m_needsSave = n; }
public Q_SLOTS: public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {}); void resetResponseState();
Q_SIGNALS: Q_SIGNALS:
void idChanged(const QString &id); void idChanged(const QString &id);
@ -143,14 +141,12 @@ Q_SIGNALS:
void modelLoadingWarning(const QString &warning); void modelLoadingWarning(const QString &warning);
void responseInProgressChanged(); void responseInProgressChanged();
void responseStateChanged(); void responseStateChanged();
void promptRequested(const QList<QString> &collectionList, const QString &prompt); void promptRequested(const QStringList &enabledCollections);
void regenerateResponseRequested(); void regenerateResponseRequested(int index);
void resetResponseRequested(); void resetResponseRequested();
void resetContextRequested(); void resetContextRequested();
void processSystemPromptRequested();
void modelChangeRequested(const ModelInfo &modelInfo); void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged(); void modelInfoChanged();
void restoringFromTextChanged();
void loadDefaultModelRequested(); void loadDefaultModelRequested();
void generateNameRequested(); void generateNameRequested();
void modelLoadingErrorChanged(); void modelLoadingErrorChanged();
@ -166,22 +162,20 @@ Q_SIGNALS:
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(const QString &response); void handleResponseChanged(const QString &response);
void handleResponseFailed(const QString &error);
void handleModelLoadingPercentageChanged(float); void handleModelLoadingPercentageChanged(float);
void promptProcessing(); void promptProcessing();
void generatingQuestions(); void generatingQuestions();
void responseStopped(qint64 promptResponseMs); void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name); void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question); void generatedQuestionFinished(const QString &question);
void handleRestoringFromText();
void handleModelLoadingError(const QString &error); void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed); void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results); void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo); void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleModelChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value); void handleTrySwitchContextOfLoadedModelCompleted(int value);
private:
void newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments);
private: private:
QString m_id; QString m_id;
QString m_name; QString m_name;

View File

@ -1,10 +1,10 @@
#include "chatapi.h" #include "chatapi.h"
#include <gpt4all-backend/llmodel.h> #include "utils.h"
#include <QCoreApplication> #include <QCoreApplication>
#include <QGuiApplication>
#include <QDebug> #include <QDebug>
#include <QGuiApplication>
#include <QJsonArray> #include <QJsonArray>
#include <QJsonDocument> #include <QJsonDocument>
#include <QJsonObject> #include <QJsonObject>
@ -13,12 +13,17 @@
#include <QNetworkRequest> #include <QNetworkRequest>
#include <QThread> #include <QThread>
#include <QUrl> #include <QUrl>
#include <QUtf8StringView>
#include <QVariant> #include <QVariant>
#include <QXmlStreamReader>
#include <Qt> #include <Qt>
#include <QtGlobal> #include <QtGlobal>
#include <QtLogging> #include <QtLogging>
#include <expected>
#include <functional>
#include <iostream> #include <iostream>
#include <utility>
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
@ -67,71 +72,119 @@ bool ChatAPI::isModelLoaded() const
return true; return true;
} }
void ChatAPI::prompt(const std::string &prompt, static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
const std::string &promptTemplate, {
std::function<bool(int32_t)> promptCallback, QJsonArray messages;
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::optional<std::string_view> fakeReply) {
Q_UNUSED(promptCallback); auto xmlError = [&xml] {
Q_UNUSED(allowContextShift); return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
Q_UNUSED(special); };
if (!isModelLoaded()) { if (xml.hasError())
std::cerr << "ChatAPI ERROR: prompt won't work with an unloaded model!\n"; return xmlError();
return; if (xml.atEnd())
return messages;
// skip header
bool foundElement = false;
do {
switch (xml.readNext()) {
using enum QXmlStreamReader::TokenType;
case Invalid:
return xmlError();
case EndDocument:
return messages;
default:
foundElement = true;
case StartDocument:
case Comment:
case DTD:
case ProcessingInstruction:
;
}
} while (!foundElement);
// document body loop
bool foundRoot = false;
for (;;) {
switch (xml.tokenType()) {
using enum QXmlStreamReader::TokenType;
case StartElement:
{
auto name = xml.name();
if (!foundRoot) {
if (name != "chat"_L1)
return std::unexpected(u"unexpected tag: %1"_s.arg(name));
foundRoot = true;
} else {
if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
return std::unexpected(u"unknown role: %1"_s.arg(name));
auto content = xml.readElementText();
if (xml.tokenType() != EndElement)
return xmlError();
messages << makeJsonObject({
{ "role"_L1, name.toString().trimmed() },
{ "content"_L1, content },
});
}
break;
}
case Characters:
if (!xml.isWhitespace())
return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
case Comment:
case ProcessingInstruction:
case EndElement:
break;
case EndDocument:
return messages;
case Invalid:
return xmlError();
default:
return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
}
xml.readNext();
} }
}
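Based on the element names the parser above checks, a prompt handed to an API model is expected to look roughly like this (the message content is illustrative only; parsePrompt is file-local, so this only shows the input shape it accepts):

    static const char kExampleApiPrompt[] =
        "<chat>"
        "  <system>You are a helpful assistant.</system>"
        "  <user>What is the capital of France?</user>"
        "</chat>";

    // QXmlStreamReader xml(kExampleApiPrompt);
    // auto messages = parsePrompt(xml);
    // // on success: [{"role":"system",...}, {"role":"user",...}]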
if (!promptCtx.n_past) { m_queuedPrompts.clear(); } void ChatAPI::prompt(
Q_ASSERT(promptCtx.n_past <= m_context.size()); std::string_view prompt,
m_context.resize(promptCtx.n_past); const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
Q_UNUSED(promptCallback)
// FIXME(cebtenzzre): We're assuming people don't try to use %2 with ChatGPT. What would that even mean? if (!isModelLoaded())
m_queuedPrompts << QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)); throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!promptCtx.n_predict)
if (!promptCtx.n_predict && !fakeReply) { return; // nothing requested
return; // response explicitly suppressed, queue prompt for later
}
QString formattedPrompt = m_queuedPrompts.join("");
m_queuedPrompts.clear();
if (fakeReply) {
promptCtx.n_past += 1;
m_context.append(formattedPrompt);
m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
return;
}
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
// the OpenAI tiktokken library or to implement our own tokenization function that matches precisely // the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of // the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
// using the REST API to count tokens in a prompt. // using the REST API to count tokens in a prompt.
QJsonObject root; auto root = makeJsonObject({
root.insert("model", m_modelName); { "model"_L1, m_modelName },
root.insert("stream", true); { "stream"_L1, true },
root.insert("temperature", promptCtx.temp); { "temperature"_L1, promptCtx.temp },
root.insert("top_p", promptCtx.top_p); { "top_p"_L1, promptCtx.top_p },
});
// conversation history // conversation history
QJsonArray messages; {
for (int i = 0; i < m_context.count(); ++i) { QUtf8StringView promptUtf8(prompt);
QJsonObject message; QXmlStreamReader xml(promptUtf8);
message.insert("role", i % 2 == 0 ? "user" : "assistant"); auto messages = parsePrompt(xml);
message.insert("content", m_context.at(i)); if (!messages) {
messages.append(message); auto error = fmt::format("Failed to parse API model prompt: {}", messages.error());
qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n';
throw std::invalid_argument(error);
}
root.insert("messages"_L1, *messages);
} }
QJsonObject promptObject;
promptObject.insert("role", "user");
promptObject.insert("content", formattedPrompt);
messages.append(promptObject);
root.insert("messages", messages);
QJsonDocument doc(root); QJsonDocument doc(root);
#if defined(DEBUG) #if defined(DEBUG)
@ -148,12 +201,9 @@ void ChatAPI::prompt(const std::string &prompt,
connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection); connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
workerThread.start(); workerThread.start();
emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact)); emit request(m_apiKey, doc.toJson(QJsonDocument::Compact));
workerThread.wait(); workerThread.wait();
promptCtx.n_past += 1;
m_context.append(formattedPrompt);
m_context.append(worker.currentResponse());
m_responseCallback = nullptr; m_responseCallback = nullptr;
#if defined(DEBUG) #if defined(DEBUG)
@ -171,12 +221,8 @@ bool ChatAPI::callResponse(int32_t token, const std::string& string)
return m_responseCallback(token, string); return m_responseCallback(token, string);
} }
void ChatAPIWorker::request(const QString &apiKey, void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array)
LLModel::PromptContext *promptCtx,
const QByteArray &array)
{ {
m_ctx = promptCtx;
QUrl apiUrl(m_chat->url()); QUrl apiUrl(m_chat->url());
const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed(); const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
QNetworkRequest request(apiUrl); QNetworkRequest request(apiUrl);
@ -283,7 +329,6 @@ void ChatAPIWorker::handleReadyRead()
const QJsonObject choice = choices.first().toObject(); const QJsonObject choice = choices.first().toObject();
const QJsonObject delta = choice.value("delta").toObject(); const QJsonObject delta = choice.value("delta").toObject();
const QString content = delta.value("content").toString(); const QString content = delta.value("content").toString();
Q_ASSERT(m_ctx);
m_currentResponse += content; m_currentResponse += content;
if (!m_chat->callResponse(0, content.toStdString())) { if (!m_chat->callResponse(0, content.toStdString())) {
reply->abort(); reply->abort();
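For illustration, the kind of gate the response callback provides; the names below are hypothetical, but returning false is exactly what makes the worker abort the reply as above:

    #include <atomic>
    #include <cstdint>
    #include <string>
    #include <string_view>

    std::atomic<bool> g_userCancelled{false};  // hypothetical cancellation flag
    std::string g_accumulated;                 // collected response text

    // Shaped like the callback each streamed delta is forwarded to.
    bool onResponsePiece(int32_t /*token*/, std::string_view piece)
    {
        g_accumulated += piece;          // keep the partial response
        return !g_userCancelled.load();  // false -> the network reply is aborted
    }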

View File

@ -7,17 +7,14 @@
#include <QNetworkReply> #include <QNetworkReply>
#include <QObject> #include <QObject>
#include <QString> #include <QString>
#include <QStringList>
#include <QList>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional>
#include <optional>
#include <span> #include <span>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <unordered_map>
#include <vector> #include <vector>
class QNetworkAccessManager; class QNetworkAccessManager;
@ -28,16 +25,13 @@ class ChatAPIWorker : public QObject {
public: public:
ChatAPIWorker(ChatAPI *chatAPI) ChatAPIWorker(ChatAPI *chatAPI)
: QObject(nullptr) : QObject(nullptr)
, m_ctx(nullptr)
, m_networkManager(nullptr) , m_networkManager(nullptr)
, m_chat(chatAPI) {} , m_chat(chatAPI) {}
virtual ~ChatAPIWorker() {} virtual ~ChatAPIWorker() {}
QString currentResponse() const { return m_currentResponse; } QString currentResponse() const { return m_currentResponse; }
void request(const QString &apiKey, void request(const QString &apiKey, const QByteArray &array);
LLModel::PromptContext *promptCtx,
const QByteArray &array);
Q_SIGNALS: Q_SIGNALS:
void finished(); void finished();
@ -49,7 +43,6 @@ private Q_SLOTS:
private: private:
ChatAPI *m_chat; ChatAPI *m_chat;
LLModel::PromptContext *m_ctx;
QNetworkAccessManager *m_networkManager; QNetworkAccessManager *m_networkManager;
QString m_currentResponse; QString m_currentResponse;
}; };
@ -74,14 +67,14 @@ public:
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); } { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
void prompt(const std::string &prompt, void prompt(std::string_view prompt,
const std::string &promptTemplate, const PromptCallback &promptCallback,
std::function<bool(int32_t)> promptCallback, const ResponseCallback &responseCallback,
std::function<bool(int32_t, const std::string&)> responseCallback, const PromptContext &ctx) override;
bool allowContextShift,
PromptContext &ctx, [[noreturn]]
bool special, int32_t countPromptTokens(std::string_view prompt) const override
std::optional<std::string_view> fakeReply) override; { Q_UNUSED(prompt); throwNotImplemented(); }
void setThreadCount(int32_t n_threads) override; void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override; int32_t threadCount() const override;
@ -91,19 +84,17 @@ public:
void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; } void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
QString url() const { return m_requestURL; } QString url() const { return m_requestURL; }
QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }
bool callResponse(int32_t token, const std::string &string); bool callResponse(int32_t token, const std::string &string);
[[noreturn]] [[noreturn]]
int32_t contextLength() const override int32_t contextLength() const override
{ throwNotImplemented(); } { throwNotImplemented(); }
auto specialTokens() -> std::unordered_map<std::string, std::string> const override
{ return {}; }
Q_SIGNALS: Q_SIGNALS:
void request(const QString &apiKey, void request(const QString &apiKey, const QByteArray &array);
LLModel::PromptContext *ctx,
const QByteArray &array);
protected: protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use // We have to implement these as they are pure virtual in base class, but we don't actually use
@ -114,8 +105,8 @@ protected:
static void throwNotImplemented() { throw std::logic_error("not implemented"); } static void throwNotImplemented() { throw std::logic_error("not implemented"); }
[[noreturn]] [[noreturn]]
std::vector<Token> tokenize(std::string_view str, bool special) override std::vector<Token> tokenize(std::string_view str) const override
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); } { Q_UNUSED(str); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
bool isSpecialToken(Token id) const override bool isSpecialToken(Token id) const override
@ -126,7 +117,7 @@ protected:
{ Q_UNUSED(id); throwNotImplemented(); } { Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
void initSampler(PromptContext &ctx) override void initSampler(const PromptContext &ctx) override
{ Q_UNUSED(ctx); throwNotImplemented(); } { Q_UNUSED(ctx); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
@ -134,33 +125,28 @@ protected:
{ throwNotImplemented(); } { throwNotImplemented(); }
[[noreturn]] [[noreturn]]
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); } { Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
void shiftContext(PromptContext &promptCtx) override void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override
{ Q_UNUSED(promptCtx); throwNotImplemented(); } { Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
int32_t inputLength() const override int32_t inputLength() const override
{ throwNotImplemented(); } { throwNotImplemented(); }
[[noreturn]] [[noreturn]]
void setTokenizeInputPosition(int32_t pos) override int32_t computeModelInputPosition(std::span<const Token> input) const override
{ Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); } { Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input) void appendInputToken(Token tok) override
-> std::vector<Token>::const_iterator override { Q_UNUSED(tok); throwNotImplemented(); }
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
void appendInputToken(PromptContext &ctx, Token tok) override
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
[[noreturn]] [[noreturn]]
const std::vector<Token> &endTokens() const override const std::vector<Token> &endTokens() const override
@ -175,12 +161,10 @@ protected:
{ throwNotImplemented(); } { throwNotImplemented(); }
private: private:
std::function<bool(int32_t, const std::string&)> m_responseCallback; ResponseCallback m_responseCallback;
QString m_modelName; QString m_modelName;
QString m_apiKey; QString m_apiKey;
QString m_requestURL; QString m_requestURL;
QList<QString> m_context;
QStringList m_queuedPrompts;
}; };
#endif // CHATAPI_H #endif // CHATAPI_H

View File

@ -17,9 +17,10 @@
#include <Qt> #include <Qt>
#include <algorithm> #include <algorithm>
#include <memory>
#define CHAT_FORMAT_MAGIC 0xF5D553CC static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC;
#define CHAT_FORMAT_VERSION 10 static constexpr qint32 CHAT_FORMAT_VERSION = 11;
class MyChatListModel: public ChatListModel { }; class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
@ -118,8 +119,8 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
} }
QDataStream out(&tempFile); QDataStream out(&tempFile);
out << (quint32)CHAT_FORMAT_MAGIC; out << CHAT_FORMAT_MAGIC;
out << (qint32)CHAT_FORMAT_VERSION; out << CHAT_FORMAT_VERSION;
out.setVersion(QDataStream::Qt_6_2); out.setVersion(QDataStream::Qt_6_2);
qDebug() << "serializing chat" << fileName; qDebug() << "serializing chat" << fileName;
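A self-contained sketch of the header round trip implied by the constants above; kMagic and kVersion mirror CHAT_FORMAT_MAGIC and CHAT_FORMAT_VERSION, and the buffer stands in for the chat file:

    #include <QBuffer>
    #include <QDataStream>
    #include <QIODevice>
    #include <QtGlobal>

    static constexpr quint32 kMagic   = 0xF5D553CC;
    static constexpr qint32  kVersion = 11;

    bool headerRoundTrips()
    {
        QBuffer buf;
        buf.open(QIODevice::ReadWrite);

        QDataStream out(&buf);
        out << kMagic << kVersion;            // header written before setVersion(), as above
        out.setVersion(QDataStream::Qt_6_2);

        buf.seek(0);
        QDataStream in(&buf);
        quint32 magic; qint32 version;
        in >> magic >> version;
        in.setVersion(QDataStream::Qt_6_2);
        return magic == kMagic && version == kVersion;
    }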
@ -257,12 +258,15 @@ void ChatsRestoreThread::run()
qDebug() << "deserializing chat" << f.file; qDebug() << "deserializing chat" << f.file;
Chat *chat = new Chat; auto chat = std::make_unique<Chat>();
chat->moveToThread(qGuiApp->thread()); chat->moveToThread(qGuiApp->thread());
if (!chat->deserialize(in, version)) { bool ok = chat->deserialize(in, version);
if (!ok) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
} else if (!in.atEnd()) {
qWarning().nospace() << "error loading chat from " << file.fileName() << ": extra data at end of file";
} else { } else {
emit chatRestored(chat); emit chatRestored(chat.release());
} }
if (f.oldFile) if (f.oldFile)
file.remove(); // No longer storing in this directory file.remove(); // No longer storing in this directory

File diff suppressed because it is too large

View File

@ -13,20 +13,24 @@
#include <QObject> #include <QObject>
#include <QPointer> #include <QPointer>
#include <QString> #include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QStringView>
#include <QThread> #include <QThread>
#include <QVariantMap> #include <QVariantMap> // IWYU pragma: keep
#include <QtGlobal> #include <QtGlobal>
#include <atomic> #include <atomic>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <span>
#include <string> #include <string>
#include <vector> #include <variant>
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
class QDataStream; class QDataStream;
struct ChatItem;
// NOTE: values serialized to disk, do not change or reuse // NOTE: values serialized to disk, do not change or reuse
enum class LLModelTypeV0 { // chat versions 2-5 enum class LLModelTypeV0 { // chat versions 2-5
@ -142,7 +146,6 @@ class Chat;
class ChatLLM : public QObject class ChatLLM : public QObject
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
@ -150,12 +153,14 @@ public:
ChatLLM(Chat *parent, bool isServer = false); ChatLLM(Chat *parent, bool isServer = false);
virtual ~ChatLLM(); virtual ~ChatLLM();
void destroy();
static void destroyStore(); static void destroyStore();
static std::optional<std::string> checkJinjaTemplateError(const std::string &source);
void destroy();
bool isModelLoaded() const; bool isModelLoaded() const;
void regenerateResponse(); void regenerateResponse(int index);
void resetResponse(); // used to implement edit functionality
void resetContext(); std::optional<QString> popPrompt(int index);
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
@ -165,13 +170,9 @@ public:
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
QString response(bool trim = true) const;
ModelInfo modelInfo() const; ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &info); void setModelInfo(const ModelInfo &info);
bool restoringFromText() const { return m_restoringFromText; }
void acquireModel(); void acquireModel();
void resetModel(); void resetModel();
@ -196,13 +197,11 @@ public:
return m_llModelInfo.fallbackReason.value_or(u""_s); return m_llModelInfo.fallbackReason.value_or(u""_s);
} }
QString generatedName() const { return QString::fromStdString(m_nameResponse); } bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt); void prompt(const QStringList &enabledCollections);
bool loadDefaultModel(); bool loadDefaultModel();
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo); bool loadModel(const ModelInfo &modelInfo);
@ -210,22 +209,19 @@ public Q_SLOTS:
void unloadModel(); void unloadModel();
void reloadModel(); void reloadModel();
void generateName(); void generateName();
void generateQuestions(qint64 elapsed);
void handleChatIdChanged(const QString &id); void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged(); void handleShouldBeLoadedChanged();
void handleThreadStarted(); void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal); void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged(); void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();
Q_SIGNALS: Q_SIGNALS:
void restoringFromTextChanged();
void loadedModelInfoChanged(); void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float); void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error); void modelLoadingError(const QString &error);
void modelLoadingWarning(const QString &warning); void modelLoadingWarning(const QString &warning);
void responseChanged(const QString &response); void responseChanged(const QString &response);
void responseFailed(const QString &error);
void promptProcessing(); void promptProcessing();
void generatingQuestions(); void generatingQuestions();
void responseStopped(qint64 promptResponseMs); void responseStopped(qint64 promptResponseMs);
@ -244,58 +240,50 @@ Q_SIGNALS:
void modelInfoChanged(const ModelInfo &modelInfo); void modelInfoChanged(const ModelInfo &modelInfo);
protected: protected:
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate, struct PromptResult {
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, QByteArray response; // raw UTF-8
int32_t repeat_penalty_tokens, std::optional<QString> fakeReply = {}); int promptTokens; // note: counts *entire* history, even if cached
bool handlePrompt(int32_t token); int responseTokens;
bool handleResponse(int32_t token, const std::string &response); };
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleQuestionPrompt(int32_t token);
bool handleQuestionResponse(int32_t token, const std::string &response);
void saveState();
void restoreState();
protected: struct ChatPromptResult : PromptResult {
LLModel::PromptContext m_ctx; QList<ResultInfo> databaseResults;
quint32 m_promptTokens; };
quint32 m_promptResponseTokens;
ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx);
// passing a string_view directly skips templating and uses the raw string
PromptResult promptInternal(const std::variant<std::span<const ChatItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
bool usedLocalDocs);
private: private:
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
std::vector<ChatItem> forkConversation(const QString &prompt) const;
// Applies the Jinja template. Query mode returns only the last message without special tokens.
// Returns the rendered prompt as a UTF-8 string.
std::string applyJinjaTemplate(std::span<const ChatItem> items) const;
void generateQuestions(qint64 elapsed);
protected:
QPointer<ChatModel> m_chatModel;
private:
const Chat *m_chat; const Chat *m_chat;
std::string m_response;
std::string m_trimmedResponse;
std::string m_nameResponse;
QString m_questionResponse;
LLModelInfo m_llModelInfo; LLModelInfo m_llModelInfo;
LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE; LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
ModelInfo m_modelInfo; ModelInfo m_modelInfo;
TokenTimer *m_timer; TokenTimer *m_timer;
QByteArray m_state;
std::vector<LLModel::Token> m_stateInputTokens;
int32_t m_stateContextLength = -1;
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded; std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_restoringFromText; // status indication
std::atomic<bool> m_forceUnloadModel; std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion; std::atomic<bool> m_markedForDeletion;
bool m_isServer; bool m_isServer;
bool m_forceMetal; bool m_forceMetal;
bool m_reloadingToChangeVariant; bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
// m_pristineLoadedState is set if saveState is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QPointer<ChatModel> m_chatModel;
}; };
#endif // CHATLLM_H #endif // CHATLLM_H
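A hedged usage sketch of the checkJinjaTemplateError() hook declared above, assuming chatllm.h is included; the template source is illustrative:

    #include <QDebug>
    #include <QString>
    #include <string>

    void validateTemplate(const std::string &source)
    {
        // Returns an error message when the Jinja source does not parse.
        if (auto err = ChatLLM::checkJinjaTemplateError(source))
            qWarning() << "invalid chat template:" << QString::fromStdString(*err);
    }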

View File

@ -2,8 +2,11 @@
#define CHATMODEL_H #define CHATMODEL_H
#include "database.h" #include "database.h"
#include "utils.h"
#include "xlsxtomd.h" #include "xlsxtomd.h"
#include <fmt/format.h>
#include <QAbstractListModel> #include <QAbstractListModel>
#include <QBuffer> #include <QBuffer>
#include <QByteArray> #include <QByteArray>
@ -18,6 +21,15 @@
#include <Qt> #include <Qt>
#include <QtGlobal> #include <QtGlobal>
#include <iterator>
#include <ranges>
#include <span>
#include <utility>
using namespace Qt::Literals::StringLiterals;
namespace ranges = std::ranges;
struct PromptAttachment { struct PromptAttachment {
Q_GADGET Q_GADGET
Q_PROPERTY(QUrl url MEMBER url) Q_PROPERTY(QUrl url MEMBER url)
@ -60,66 +72,145 @@ Q_DECLARE_METATYPE(PromptAttachment)
struct ChatItem struct ChatItem
{ {
Q_GADGET Q_GADGET
Q_PROPERTY(QString name MEMBER name) Q_PROPERTY(QString name MEMBER name )
Q_PROPERTY(QString value MEMBER value) Q_PROPERTY(QString value MEMBER value)
Q_PROPERTY(QString newResponse MEMBER newResponse)
Q_PROPERTY(bool currentResponse MEMBER currentResponse) // prompts
Q_PROPERTY(bool stopped MEMBER stopped)
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
Q_PROPERTY(QList<ResultInfo> sources MEMBER sources)
Q_PROPERTY(QList<ResultInfo> consolidatedSources MEMBER consolidatedSources)
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments) Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments)
Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments) Q_PROPERTY(QString bakedPrompt READ bakedPrompt )
// responses
Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse)
Q_PROPERTY(bool isError MEMBER isError )
// responses (DataLake)
Q_PROPERTY(QString newResponse MEMBER newResponse )
Q_PROPERTY(bool stopped MEMBER stopped )
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState )
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
public: public:
QString promptPlusAttachments() const enum class Type { System, Prompt, Response };
{
QStringList attachedContexts;
for (auto attached : promptAttachments)
attachedContexts << attached.processedContent();
QString promptPlus = value; // tags for constructing ChatItems
if (!attachedContexts.isEmpty()) struct prompt_tag_t { explicit prompt_tag_t() = default; };
promptPlus = attachedContexts.join("\n\n") + "\n\n" + value; static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t();
return promptPlus; struct response_tag_t { explicit response_tag_t() = default; };
static inline constexpr response_tag_t response_tag = response_tag_t();
struct system_tag_t { explicit system_tag_t() = default; };
static inline constexpr system_tag_t system_tag = system_tag_t();
// FIXME(jared): This should not be necessary. QML should see null or undefined if it
// tries to access something invalid.
ChatItem() = default;
// NOTE: system messages are currently never stored in the model or serialized
ChatItem(system_tag_t, const QString &value)
: name(u"System: "_s), value(value) {}
ChatItem(prompt_tag_t, const QString &value, const QList<PromptAttachment> &attachments = {})
: name(u"Prompt: "_s), value(value), promptAttachments(attachments) {}
ChatItem(response_tag_t, bool isCurrentResponse = true)
: name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {}
Type type() const
{
if (name == u"System: "_s)
return Type::System;
if (name == u"Prompt: "_s)
return Type::Prompt;
if (name == u"Response: "_s)
return Type::Response;
throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name));
}
// used with version 0 Jinja templates
QString bakedPrompt() const
{
if (type() != Type::Prompt)
throw std::logic_error("bakedPrompt() called on non-prompt item");
QStringList parts;
if (!sources.isEmpty()) {
parts << u"### Context:\n"_s;
for (auto &source : std::as_const(sources))
parts << u"Collection: "_s << source.collection
<< u"\nPath: "_s << source.path
<< u"\nExcerpt: "_s << source.text << u"\n\n"_s;
}
for (auto &attached : std::as_const(promptAttachments))
parts << attached.processedContent() << u"\n\n"_s;
parts << value;
return parts.join(QString());
} }
// TODO: Maybe we should include the model name here as well as timestamp? // TODO: Maybe we should include the model name here as well as timestamp?
QString name; QString name;
QString value; QString value;
QString newResponse;
QList<ResultInfo> sources; // prompts
QList<ResultInfo> consolidatedSources; QList<ResultInfo> sources;
QList<ResultInfo> consolidatedSources;
QList<PromptAttachment> promptAttachments; QList<PromptAttachment> promptAttachments;
bool currentResponse = false;
bool stopped = false; // responses
bool thumbsUpState = false; bool isCurrentResponse = false;
bool thumbsDownState = false; bool isError = false;
// responses (DataLake)
QString newResponse;
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
}; };
Q_DECLARE_METATYPE(ChatItem) Q_DECLARE_METATYPE(ChatItem)
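A brief sketch of how the tag-dispatched constructors above are meant to be used; the values are illustrative, and the u""_s literals rely on the StringLiterals using-directive already present in this header:

    void chatItemTagExample()
    {
        ChatItem systemMsg{ChatItem::system_tag,   u"You are a helpful assistant."_s};
        ChatItem promptMsg{ChatItem::prompt_tag,   u"Hello!"_s};
        ChatItem reply    {ChatItem::response_tag, /*isCurrentResponse*/ true};
        Q_ASSERT(systemMsg.type() == ChatItem::Type::System);
        Q_ASSERT(promptMsg.type() == ChatItem::Type::Prompt);
        Q_ASSERT(reply.type()     == ChatItem::Type::Response);
    }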
using ChatModelIterator = QList<ChatItem>::const_iterator; class ChatModelAccessor : public ranges::subrange<QList<ChatItem>::const_iterator> {
private:
using Super = ranges::subrange<QList<ChatItem>::const_iterator>;
public:
template <typename... T>
ChatModelAccessor(QMutex &mutex, T &&...args)
: Super(std::forward<T>(args)...), m_lock(&mutex) {}
private:
QMutexLocker<QMutex> m_lock;
};
class ChatModel : public QAbstractListModel class ChatModel : public QAbstractListModel
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged) Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(bool hasError READ hasError NOTIFY hasErrorChanged)
public: public:
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {} explicit ChatModel(QObject *parent = nullptr)
: QAbstractListModel(parent) {}
// FIXME(jared): can't this start at Qt::UserRole (no +1)?
enum Roles { enum Roles {
NameRole = Qt::UserRole + 1, NameRole = Qt::UserRole + 1,
ValueRole, ValueRole,
// prompts and responses
PeerRole,
// prompts
PromptAttachmentsRole,
// responses
// NOTE: sources are stored on the *prompts*, but in the model, they are only on the *responses*!
SourcesRole,
ConsolidatedSourcesRole,
IsCurrentResponseRole,
IsErrorRole,
// responses (DataLake)
NewResponseRole, NewResponseRole,
CurrentResponseRole,
StoppedRole, StoppedRole,
ThumbsUpStateRole, ThumbsUpStateRole,
ThumbsDownStateRole, ThumbsDownStateRole,
SourcesRole,
ConsolidatedSourcesRole,
PromptAttachmentsRole
}; };
int rowCount(const QModelIndex &parent = QModelIndex()) const override int rowCount(const QModelIndex &parent = QModelIndex()) const override
@ -129,34 +220,96 @@ public:
return m_chatItems.size(); return m_chatItems.size();
} }
/* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs
* sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */
auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const
-> std::optional<QList<ChatItem>::const_iterator>
{
switch (item->type()) {
using enum ChatItem::Type;
case Prompt:
{
auto peer = std::next(item);
if (peer < m_chatItems.cend() && peer->type() == Response)
return peer;
break;
}
case Response:
{
if (item > m_chatItems.cbegin()) {
if (auto peer = std::prev(item); peer->type() == Prompt)
return peer;
}
break;
}
default:
throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
}
return std::nullopt;
}
auto getPeerUnlocked(int index) const -> std::optional<int>
{
return getPeerUnlocked(m_chatItems.cbegin() + index)
.transform([&](auto &&i) { return i - m_chatItems.cbegin(); } );
}
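For example, with rows [Prompt, Response, Prompt] the mapping above gives:

    // getPeerUnlocked(0) == std::optional<int>(1)   the prompt's response
    // getPeerUnlocked(1) == std::optional<int>(0)   the response's prompt
    // getPeerUnlocked(2) == std::nullopt            no response appended yet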
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant(); return QVariant();
const ChatItem &item = m_chatItems.at(index.row()); auto item = m_chatItems.cbegin() + index.row();
switch (role) { switch (role) {
case NameRole: case NameRole:
return item.name; return item->name;
case ValueRole: case ValueRole:
return item.value; return item->value;
case NewResponseRole: case PeerRole:
return item.newResponse; switch (item->type()) {
case CurrentResponseRole: using enum ChatItem::Type;
return item.currentResponse; case Prompt:
case StoppedRole: case Response:
return item.stopped; {
case ThumbsUpStateRole: auto peer = getPeerUnlocked(item);
return item.thumbsUpState; return peer ? QVariant::fromValue(**peer) : QVariant::fromValue(nullptr);
case ThumbsDownStateRole: }
return item.thumbsDownState; default:
case SourcesRole: return QVariant();
return QVariant::fromValue(item.sources); }
case ConsolidatedSourcesRole:
return QVariant::fromValue(item.consolidatedSources);
case PromptAttachmentsRole: case PromptAttachmentsRole:
return QVariant::fromValue(item.promptAttachments); return QVariant::fromValue(item->promptAttachments);
case SourcesRole:
{
QList<ResultInfo> data;
if (item->type() == ChatItem::Type::Response) {
if (auto prompt = getPeerUnlocked(item))
data = (*prompt)->consolidatedSources;
}
return QVariant::fromValue(data);
}
case ConsolidatedSourcesRole:
{
QList<ResultInfo> data;
if (item->type() == ChatItem::Type::Response) {
if (auto prompt = getPeerUnlocked(item))
data = (*prompt)->sources;
}
return QVariant::fromValue(data);
}
case IsCurrentResponseRole:
return item->isCurrentResponse;
case NewResponseRole:
return item->newResponse;
case StoppedRole:
return item->stopped;
case ThumbsUpStateRole:
return item->thumbsUpState;
case ThumbsDownStateRole:
return item->thumbsDownState;
case IsErrorRole:
return item->type() == ChatItem::Type::Response && item->isError;
} }
return QVariant(); return QVariant();
@ -164,54 +317,126 @@ public:
QHash<int, QByteArray> roleNames() const override QHash<int, QByteArray> roleNames() const override
{ {
QHash<int, QByteArray> roles; return {
roles[NameRole] = "name"; { NameRole, "name" },
roles[ValueRole] = "value"; { ValueRole, "value" },
roles[NewResponseRole] = "newResponse"; { PeerRole, "peer" },
roles[CurrentResponseRole] = "currentResponse"; { PromptAttachmentsRole, "promptAttachments" },
roles[StoppedRole] = "stopped"; { SourcesRole, "sources" },
roles[ThumbsUpStateRole] = "thumbsUpState"; { ConsolidatedSourcesRole, "consolidatedSources" },
roles[ThumbsDownStateRole] = "thumbsDownState"; { IsCurrentResponseRole, "isCurrentResponse" },
roles[SourcesRole] = "sources"; { IsErrorRole, "isError" },
roles[ConsolidatedSourcesRole] = "consolidatedSources"; { NewResponseRole, "newResponse" },
roles[PromptAttachmentsRole] = "promptAttachments"; { StoppedRole, "stopped" },
return roles; { ThumbsUpStateRole, "thumbsUpState" },
{ ThumbsDownStateRole, "thumbsDownState" },
};
} }
void appendPrompt(const QString &name, const QString &value, const QList<PromptAttachment> &attachments) void appendPrompt(const QString &value, const QList<PromptAttachment> &attachments = {})
{ {
ChatItem item; qsizetype count;
item.name = name; {
item.value = value; QMutexLocker locker(&m_mutex);
item.promptAttachments << attachments; if (hasErrorUnlocked())
throw std::logic_error("cannot append to a failed chat");
count = m_chatItems.count();
}
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count); beginInsertRows(QModelIndex(), count, count);
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
m_chatItems.append(item); m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments);
} }
endInsertRows(); endInsertRows();
emit countChanged(); emit countChanged();
} }
void appendResponse(const QString &name) void appendResponse(int promptIndex)
{ {
m_mutex.lock(); qsizetype count;
const int count = m_chatItems.count(); {
m_mutex.unlock(); QMutexLocker locker(&m_mutex);
ChatItem item; if (hasErrorUnlocked())
item.name = name; throw std::logic_error("cannot append to a failed chat");
item.currentResponse = true; count = m_chatItems.count();
}
beginInsertRows(QModelIndex(), count, count); beginInsertRows(QModelIndex(), count, count);
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
m_chatItems.append(item); if (promptIndex >= 0) {
if (promptIndex >= m_chatItems.size())
throw std::out_of_range(fmt::format("index {} is out of range", promptIndex));
auto &promptItem = m_chatItems[promptIndex];
if (promptItem.type() != ChatItem::Type::Prompt)
throw std::invalid_argument(fmt::format("item at index {} is not a prompt", promptIndex));
}
m_chatItems.emplace_back(ChatItem::response_tag, promptIndex);
} }
endInsertRows(); endInsertRows();
emit countChanged(); emit countChanged();
if (promptIndex >= 0)
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
}
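A hedged usage sketch of the two append helpers above, mirroring how Chat::newPromptResponsePair() drives them; appendExchange is an illustrative name, not part of the patch:

    void appendExchange(ChatModel &model, const QString &userText)
    {
        const int promptRow = int(model.count());  // row the new prompt will occupy
        model.appendPrompt(userText);              // Prompt row
        model.appendResponse(promptRow);           // Response row, peered with the prompt
    }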
// Used by Server to append a new conversation to the chat log.
void appendResponseWithHistory(std::span<const ChatItem> history)
{
if (history.empty())
throw std::invalid_argument("at least one message is required");
m_mutex.lock();
qsizetype startIndex = m_chatItems.count();
m_mutex.unlock();
qsizetype nNewItems = history.size() + 1;
qsizetype endIndex = startIndex + nNewItems;
beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
bool hadError;
int promptIndex;
{
QMutexLocker locker(&m_mutex);
hadError = hasErrorUnlocked();
m_chatItems.reserve(m_chatItems.count() + nNewItems);
for (auto &item : history)
m_chatItems << item;
m_chatItems.emplace_back(ChatItem::response_tag);
}
endInsertRows();
emit countChanged();
// Server can add messages when there is an error because each call is a new conversation
if (hadError)
emit hasErrorChanged(false);
if (promptIndex >= 0)
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
}
void truncate(qsizetype size)
{
qsizetype oldSize;
{
QMutexLocker locker(&m_mutex);
if (size >= (oldSize = m_chatItems.size()))
return;
if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response)
throw std::invalid_argument(
fmt::format("chat model truncated to {} items would not end in a response", size)
);
}
bool oldHasError;
beginRemoveRows(QModelIndex(), size, oldSize - 1 /*inclusive*/);
{
QMutexLocker locker(&m_mutex);
oldHasError = hasErrorUnlocked();
Q_ASSERT(size < m_chatItems.size());
m_chatItems.resize(size);
}
endRemoveRows();
emit countChanged();
if (oldHasError)
emit hasErrorChanged(false);
} }
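A small sketch of the truncation contract above; truncateAfterResponse is an illustrative helper, not part of the patch:

    // Drop every row after responseRow, e.g. before regenerating that response.
    // Throws if row responseRow is not a Response, per the check above.
    void truncateAfterResponse(ChatModel &model, qsizetype responseRow)
    {
        model.truncate(responseRow + 1);
    }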
Q_INVOKABLE void clear() Q_INVOKABLE void clear()
@ -221,13 +446,17 @@ public:
if (m_chatItems.isEmpty()) return; if (m_chatItems.isEmpty()) return;
} }
bool oldHasError;
beginResetModel(); beginResetModel();
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
oldHasError = hasErrorUnlocked();
m_chatItems.clear(); m_chatItems.clear();
} }
endResetModel(); endResetModel();
emit countChanged(); emit countChanged();
if (oldHasError)
emit hasErrorChanged(false);
} }
Q_INVOKABLE ChatItem get(int index) Q_INVOKABLE ChatItem get(int index)
@ -245,13 +474,13 @@ public:
if (index < 0 || index >= m_chatItems.size()) return; if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) { if (item.isCurrentResponse != b) {
item.currentResponse = b; item.isCurrentResponse = b;
changed = true; changed = true;
} }
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole});
} }
Q_INVOKABLE void updateStopped(int index, bool b) Q_INVOKABLE void updateStopped(int index, bool b)
@ -304,16 +533,23 @@ public:
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources) Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
{ {
int responseIndex = -1;
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return; if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; auto promptItem = m_chatItems.begin() + index;
item.sources = sources; if (promptItem->type() != ChatItem::Type::Prompt)
item.consolidatedSources = consolidateSources(sources); throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index));
if (auto peer = getPeerUnlocked(promptItem))
responseIndex = *peer - m_chatItems.cbegin();
promptItem->sources = sources;
promptItem->consolidatedSources = consolidateSources(sources);
}
if (responseIndex >= 0) {
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole});
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {ConsolidatedSourcesRole});
} }
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
} }
Q_INVOKABLE void updateThumbsUpState(int index, bool b) Q_INVOKABLE void updateThumbsUpState(int index, bool b)
@ -364,18 +600,56 @@ public:
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
} }
int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } Q_INVOKABLE void setError(bool value = true)
{
qsizetype index;
{
QMutexLocker locker(&m_mutex);
ChatModelIterator begin() const { return m_chatItems.begin(); } if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response)
ChatModelIterator end() const { return m_chatItems.end(); } throw std::logic_error("can only set error on a chat that ends with a response");
void lock() { m_mutex.lock(); }
void unlock() { m_mutex.unlock(); } index = m_chatItems.count() - 1;
auto &last = m_chatItems.back();
if (last.isError == value)
return; // already set
last.isError = value;
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole});
emit hasErrorChanged(value);
}
qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; }
bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); }
bool serialize(QDataStream &stream, int version) const bool serialize(QDataStream &stream, int version) const
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
stream << int(m_chatItems.size()); stream << int(m_chatItems.size());
for (const auto &c : m_chatItems) { for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) {
auto c = *itemIt; // NB: copies
if (version < 11) {
// move sources from their prompt to the next response
switch (c.type()) {
using enum ChatItem::Type;
case Prompt:
c.sources.clear();
c.consolidatedSources.clear();
break;
case Response:
// note: we drop sources for responseless prompts
if (auto peer = getPeerUnlocked(itemIt)) {
c.sources = (*peer)->sources;
c.consolidatedSources = (*peer)->consolidatedSources;
}
default:
;
}
}
// FIXME: This 'id' should be eliminated the next time we bump serialization version. // FIXME: This 'id' should be eliminated the next time we bump serialization version.
// (Jared) This was apparently never used. // (Jared) This was apparently never used.
int id = 0; int id = 0;
@ -383,10 +657,12 @@ public:
stream << c.name; stream << c.name;
stream << c.value; stream << c.value;
stream << c.newResponse; stream << c.newResponse;
stream << c.currentResponse; stream << c.isCurrentResponse;
stream << c.stopped; stream << c.stopped;
stream << c.thumbsUpState; stream << c.thumbsUpState;
stream << c.thumbsDownState; stream << c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response)
stream << c.isError;
if (version >= 8) { if (version >= 8) {
stream << c.sources.size(); stream << c.sources.size();
for (const ResultInfo &info : c.sources) { for (const ResultInfo &info : c.sources) {
@ -452,14 +728,24 @@ public:
bool deserialize(QDataStream &stream, int version) bool deserialize(QDataStream &stream, int version)
{ {
clear(); // reset to known state
int size; int size;
stream >> size; stream >> size;
int lastPromptIndex = -1;
QList<ChatItem> chatItems;
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
ChatItem c; ChatItem c;
// FIXME: see comment in serialization about id // FIXME: see comment in serialization about id
int id; int id;
stream >> id; stream >> id;
stream >> c.name; stream >> c.name;
try {
c.type(); // check name
} catch (const std::exception &e) {
qWarning() << "ChatModel ERROR:" << e.what();
return false;
}
stream >> c.value; stream >> c.value;
if (version < 10) { if (version < 10) {
// This is deprecated and no longer used // This is deprecated and no longer used
@ -467,10 +753,12 @@ public:
stream >> prompt; stream >> prompt;
} }
stream >> c.newResponse; stream >> c.newResponse;
stream >> c.currentResponse; stream >> c.isCurrentResponse;
stream >> c.stopped; stream >> c.stopped;
stream >> c.thumbsUpState; stream >> c.thumbsUpState;
stream >> c.thumbsDownState; stream >> c.thumbsDownState;
if (version >= 11 && c.type() == ChatItem::Type::Response)
stream >> c.isError;
if (version >= 8) { if (version >= 8) {
qsizetype count; qsizetype count;
stream >> count; stream >> count;
@ -587,23 +875,53 @@ public:
} }
c.promptAttachments = attachments; c.promptAttachments = attachments;
} }
m_mutex.lock();
const int count = m_chatItems.size(); if (version < 11 && c.type() == ChatItem::Type::Response) {
m_mutex.unlock(); // move sources from the response to their last prompt
beginInsertRows(QModelIndex(), count, count); if (lastPromptIndex >= 0) {
{ auto &prompt = chatItems[lastPromptIndex];
QMutexLocker locker(&m_mutex); prompt.sources = std::move(c.sources );
m_chatItems.append(c); prompt.consolidatedSources = std::move(c.consolidatedSources);
lastPromptIndex = -1;
} else {
// drop sources for promptless responses
c.sources.clear();
c.consolidatedSources.clear();
}
} }
endInsertRows();
chatItems << c;
if (c.type() == ChatItem::Type::Prompt)
lastPromptIndex = chatItems.size() - 1;
} }
bool hasError;
beginInsertRows(QModelIndex(), 0, chatItems.size() - 1 /*inclusive*/);
{
QMutexLocker locker(&m_mutex);
m_chatItems = chatItems;
hasError = hasErrorUnlocked();
}
endInsertRows();
emit countChanged(); emit countChanged();
if (hasError)
emit hasErrorChanged(true);
return stream.status() == QDataStream::Ok; return stream.status() == QDataStream::Ok;
} }
Q_SIGNALS: Q_SIGNALS:
void countChanged(); void countChanged();
void valueChanged(int index, const QString &value); void valueChanged(int index, const QString &value);
void hasErrorChanged(bool value);
private:
bool hasErrorUnlocked() const
{
if (m_chatItems.isEmpty())
return false;
auto &last = m_chatItems.back();
return last.type() == ChatItem::Type::Response && last.isError;
}
private: private:
mutable QMutex m_mutex; mutable QMutex m_mutex;
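
The serialize()/deserialize() changes above keep sources attached to the prompt in memory (the new version-11 layout) while still reading and writing the pre-11 layout, where sources lived on the response. A minimal sketch of that migration logic, using simplified stand-in types instead of the real ChatItem, QList and QDataStream code (the peer lookup is reduced to "the next/previous adjacent item"):

#include <optional>
#include <string>
#include <vector>

// Simplified stand-in for ChatItem, for illustration only.
struct Item {
    enum class Type { System, Prompt, Response } type;
    std::vector<std::string> sources; // stored on the Prompt as of version 11
};

// Writing a pre-11 stream: copy each prompt's sources onto the response that
// follows it, and clear them from the prompt itself.
std::vector<Item> toLegacyLayout(std::vector<Item> items)
{
    for (size_t i = 0; i < items.size(); ++i) {
        if (items[i].type != Item::Type::Prompt)
            continue;
        auto sources = std::move(items[i].sources);
        items[i].sources.clear();
        if (i + 1 < items.size() && items[i + 1].type == Item::Type::Response)
            items[i + 1].sources = std::move(sources); // peer response
        // otherwise the sources are dropped, as in the real serializer
    }
    return items;
}

// Reading a pre-11 stream: move each response's sources back onto the most
// recent prompt that does not have a response yet.
std::vector<Item> fromLegacyLayout(std::vector<Item> items)
{
    std::optional<size_t> lastPrompt;
    for (size_t i = 0; i < items.size(); ++i) {
        switch (items[i].type) {
        case Item::Type::Prompt:
            lastPrompt = i;
            break;
        case Item::Type::Response:
            if (lastPrompt) {
                items[*lastPrompt].sources = std::move(items[i].sources);
                lastPrompt.reset();
            }
            items[i].sources.clear(); // drop sources for promptless responses
            break;
        default:
            break;
        }
    }
    return items;
}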

View File

@ -0,0 +1,111 @@
#include "jinja_helpers.h"
#include "utils.h"
#include <fmt/format.h>
#include <QString>
#include <QUrl>
#include <memory>
#include <vector>
using namespace std::literals::string_view_literals;
JinjaResultInfo::~JinjaResultInfo() = default;
const JinjaFieldMap<ResultInfo> JinjaResultInfo::s_fields = {
{ "collection", [](auto &s) { return s.collection.toStdString(); } },
{ "path", [](auto &s) { return s.path .toStdString(); } },
{ "file", [](auto &s) { return s.file .toStdString(); } },
{ "title", [](auto &s) { return s.title .toStdString(); } },
{ "author", [](auto &s) { return s.author .toStdString(); } },
{ "date", [](auto &s) { return s.date .toStdString(); } },
{ "text", [](auto &s) { return s.text .toStdString(); } },
{ "page", [](auto &s) { return s.page; } },
{ "file_uri", [](auto &s) { return s.fileUri() .toStdString(); } },
};
JinjaPromptAttachment::~JinjaPromptAttachment() = default;
const JinjaFieldMap<PromptAttachment> JinjaPromptAttachment::s_fields = {
{ "url", [](auto &s) { return s.url.toString() .toStdString(); } },
{ "file", [](auto &s) { return s.file() .toStdString(); } },
{ "processed_content", [](auto &s) { return s.processedContent().toStdString(); } },
};
std::vector<std::string> JinjaMessage::GetKeys() const
{
std::vector<std::string> result;
auto &keys = this->keys();
result.reserve(keys.size());
result.assign(keys.begin(), keys.end());
return result;
}
auto JinjaMessage::keys() const -> const std::unordered_set<std::string_view> &
{
static const std::unordered_set<std::string_view> baseKeys
{ "role", "content" };
static const std::unordered_set<std::string_view> userKeys
{ "role", "content", "sources", "prompt_attachments" };
switch (m_item->type()) {
using enum ChatItem::Type;
case System:
case Response:
return baseKeys;
case Prompt:
return userKeys;
}
Q_UNREACHABLE();
}
bool operator==(const JinjaMessage &a, const JinjaMessage &b)
{
if (a.m_item == b.m_item)
return true;
const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item);
auto type = ia.type();
if (type != ib.type() || ia.value != ib.value)
return false;
switch (type) {
using enum ChatItem::Type;
case System:
case Response:
return true;
case Prompt:
return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments;
}
Q_UNREACHABLE();
}
const JinjaFieldMap<JinjaMessage> JinjaMessage::s_fields = {
{ "role", [](auto &m) {
switch (m.item().type()) {
using enum ChatItem::Type;
case System: return "system"sv;
case Prompt: return "user"sv;
case Response: return "assistant"sv;
}
Q_UNREACHABLE();
} },
{ "content", [](auto &m) {
if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt)
return m.item().bakedPrompt().toStdString();
return m.item().value.toStdString();
} },
{ "sources", [](auto &m) {
auto sources = m.item().sources | views::transform([](auto &r) {
return jinja2::GenericMap([map = std::make_shared<JinjaResultInfo>(r)] { return map.get(); });
});
return jinja2::ValuesList(sources.begin(), sources.end());
} },
{ "prompt_attachments", [](auto &m) {
auto attachments = m.item().promptAttachments | views::transform([](auto &pa) {
return jinja2::GenericMap([map = std::make_shared<JinjaPromptAttachment>(pa)] { return map.get(); });
});
return jinja2::ValuesList(attachments.begin(), attachments.end());
} },
};
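
The pattern used throughout this file is a static map from Jinja attribute names to extractor lambdas, so each C++ object can be exposed to jinja2cpp as a read-only dict. A reduced sketch of the idea with a plain struct; the field names are borrowed from ResultInfo above, but the return type is simplified to std::string, whereas the real maps return jinja2::Value:

#include <functional>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

struct Source {
    std::string collection;
    std::string title;
    int page = 0;
};

// name -> extractor; mirrors JinjaFieldMap<T>, but with std::string for brevity.
using FieldMap = std::unordered_map<std::string_view,
                                    std::function<std::string(const Source &)>>;

static const FieldMap sourceFields = {
    { "collection", [](const Source &s) { return s.collection; } },
    { "title",      [](const Source &s) { return s.title; } },
    { "page",       [](const Source &s) { return std::to_string(s.page); } },
};

int main()
{
    Source s{"localdocs", "Report.pdf", 3};
    // A template asking for {{ source.title }} reduces to a map lookup:
    if (auto it = sourceFields.find("title"); it != sourceFields.end())
        std::cout << it->second(s) << '\n'; // prints "Report.pdf"
}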

View File

@ -0,0 +1,116 @@
#pragma once
#include "chatmodel.h"
#include "database.h"
#include <jinja2cpp/value.h>
#include <functional>
#include <ranges>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <QtGlobal>
namespace views = std::views;
template <typename T>
using JinjaFieldMap = std::unordered_map<std::string_view, std::function<jinja2::Value (const T &)>>;
template <typename Derived>
class JinjaComparable : public jinja2::IMapItemAccessor {
public:
JinjaComparable() = default;
bool IsEqual(const jinja2::IComparable &other) const override;
private:
Q_DISABLE_COPY_MOVE(JinjaComparable)
};
template <typename Derived>
class JinjaHelper : public JinjaComparable<Derived> {
public:
size_t GetSize() const override
{ return Derived::s_fields.size(); }
bool HasValue(const std::string &name) const override
{ return Derived::s_fields.contains(name); }
jinja2::Value GetValueByName(const std::string &name) const override;
std::vector<std::string> GetKeys() const override
{ auto keys = views::elements<0>(Derived::s_fields); return { keys.begin(), keys.end() }; }
};
class JinjaResultInfo : public JinjaHelper<JinjaResultInfo> {
public:
explicit JinjaResultInfo(const ResultInfo &source) noexcept
: m_source(&source) {}
~JinjaResultInfo() override;
const ResultInfo &value() const { return *m_source; }
friend bool operator==(const JinjaResultInfo &a, const JinjaResultInfo &b)
{ return a.m_source == b.m_source || *a.m_source == *b.m_source; }
private:
static const JinjaFieldMap<ResultInfo> s_fields;
const ResultInfo *m_source;
friend class JinjaHelper<JinjaResultInfo>;
};
class JinjaPromptAttachment : public JinjaHelper<JinjaPromptAttachment> {
public:
explicit JinjaPromptAttachment(const PromptAttachment &attachment) noexcept
: m_attachment(&attachment) {}
~JinjaPromptAttachment() override;
const PromptAttachment &value() const { return *m_attachment; }
friend bool operator==(const JinjaPromptAttachment &a, const JinjaPromptAttachment &b)
{ return a.m_attachment == b.m_attachment || *a.m_attachment == *b.m_attachment; }
private:
static const JinjaFieldMap<PromptAttachment> s_fields;
const PromptAttachment *m_attachment;
friend class JinjaHelper<JinjaPromptAttachment>;
};
class JinjaMessage : public JinjaHelper<JinjaMessage> {
public:
explicit JinjaMessage(uint version, const ChatItem &item) noexcept
: m_version(version), m_item(&item) {}
const JinjaMessage &value () const { return *this; }
uint version() const { return m_version; }
const ChatItem &item () const { return *m_item; }
size_t GetSize() const override { return keys().size(); }
bool HasValue(const std::string &name) const override { return keys().contains(name); }
jinja2::Value GetValueByName(const std::string &name) const override
{ return HasValue(name) ? JinjaHelper::GetValueByName(name) : jinja2::EmptyValue(); }
std::vector<std::string> GetKeys() const override;
private:
auto keys() const -> const std::unordered_set<std::string_view> &;
private:
static const JinjaFieldMap<JinjaMessage> s_fields;
uint m_version;
const ChatItem *m_item;
friend class JinjaHelper<JinjaMessage>;
friend bool operator==(const JinjaMessage &a, const JinjaMessage &b);
};
#include "jinja_helpers.inl"
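
JinjaMessage narrows which keys a chat template can see based on the message type: system and assistant messages expose only role and content, while user messages additionally expose sources and prompt_attachments. A small sketch of that key-visibility rule on its own, independent of the jinja2cpp interfaces:

#include <string_view>
#include <unordered_set>

enum class MessageType { System, Prompt, Response };

// Which attributes a template may reference for a given message, mirroring
// JinjaMessage::keys() above.
const std::unordered_set<std::string_view> &visibleKeys(MessageType type)
{
    static const std::unordered_set<std::string_view> base
        { "role", "content" };
    static const std::unordered_set<std::string_view> user
        { "role", "content", "sources", "prompt_attachments" };
    return type == MessageType::Prompt ? user : base;
}

// e.g. visibleKeys(MessageType::Response).contains("sources") == false, so a
// template referencing message.sources on an assistant turn gets an empty
// value (JinjaMessage::GetValueByName returns jinja2::EmptyValue for it).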

View File

@ -0,0 +1,17 @@
template <typename D>
bool JinjaComparable<D>::IsEqual(const jinja2::IComparable &other) const
{
if (auto *omsg = dynamic_cast<const D *>(&other))
return *static_cast<const D *>(this) == *omsg;
return false;
}
template <typename D>
jinja2::Value JinjaHelper<D>::GetValueByName(const std::string &name) const
{
if (auto it = D::s_fields.find(name); it != D::s_fields.end()) {
auto [_, func] = *it;
return func(static_cast<const D *>(this)->value());
}
return jinja2::EmptyValue();
}
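
jinja_helpers.inl completes the CRTP setup: a derived class only supplies a static s_fields table plus value(), and the base template implements the lookup queries against it once. A compact sketch of the same dispatch without the jinja2cpp interfaces; Accessor, Fields, has() and get() are illustrative names, not the real API:

#include <functional>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

template <typename Derived>
struct Accessor {
    // Generic lookups implemented once, driven by Derived::s_fields.
    bool has(std::string_view name) const
    { return Derived::s_fields.contains(name); }

    std::string get(std::string_view name) const
    {
        auto it = Derived::s_fields.find(name);
        if (it == Derived::s_fields.end())
            return {}; // the real helper returns jinja2::EmptyValue()
        return it->second(static_cast<const Derived *>(this)->value());
    }
};

struct Attachment : Accessor<Attachment> {
    std::string file;
    const Attachment &value() const { return *this; }

    using Fields = std::unordered_map<std::string_view,
                                      std::function<std::string(const Attachment &)>>;
    static const Fields s_fields;
};

const Attachment::Fields Attachment::s_fields = {
    { "file", [](const Attachment &a) { return a.file; } },
};

int main()
{
    Attachment a;
    a.file = "notes.txt";
    std::cout << a.get("file") << '\n'; // "notes.txt"
    std::cout << a.has("url") << '\n';  // 0
}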

View File

@ -12,12 +12,16 @@
#include <singleapplication.h> #include <singleapplication.h>
#include <QCoreApplication> #include <QCoreApplication>
#include <QFont>
#include <QFontDatabase>
#include <QObject> #include <QObject>
#include <QQmlApplicationEngine> #include <QQmlApplicationEngine>
#include <QQmlContext>
#include <QQuickWindow> #include <QQuickWindow>
#include <QSettings> #include <QSettings>
#include <QString> #include <QString>
#include <QUrl> #include <QUrl>
#include <QVariant>
#include <Qt> #include <Qt>
#ifdef Q_OS_LINUX #ifdef Q_OS_LINUX
@ -91,18 +95,22 @@ int main(int argc, char *argv[])
// Set the local and language translation before the qml engine has even been started. This will // Set the local and language translation before the qml engine has even been started. This will
// use the default system locale unless the user has explicitly set it to use a different one. // use the default system locale unless the user has explicitly set it to use a different one.
MySettings::globalInstance()->setLanguageAndLocale(); auto *mySettings = MySettings::globalInstance();
mySettings->setLanguageAndLocale();
QQmlApplicationEngine engine; QQmlApplicationEngine engine;
// Add a connection here from MySettings::languageAndLocaleChanged signal to a lambda slot where I can call // Add a connection here from MySettings::languageAndLocaleChanged signal to a lambda slot where I can call
// engine.uiLanguage property // engine.uiLanguage property
QObject::connect(MySettings::globalInstance(), &MySettings::languageAndLocaleChanged, [&engine]() { QObject::connect(mySettings, &MySettings::languageAndLocaleChanged, [&engine]() {
engine.setUiLanguage(MySettings::globalInstance()->languageAndLocale()); engine.setUiLanguage(MySettings::globalInstance()->languageAndLocale());
}); });
qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance()); auto *modelList = ModelList::globalInstance();
qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance()); QObject::connect(modelList, &ModelList::dataChanged, mySettings, &MySettings::onModelInfoChanged);
qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", mySettings);
qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", modelList);
qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance()); qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance());
qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance()); qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance());
qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance());
@ -110,6 +118,11 @@ int main(int argc, char *argv[])
qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance());
qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums");
{
auto fixedFont = QFontDatabase::systemFont(QFontDatabase::FixedFont);
engine.rootContext()->setContextProperty("fixedFont", fixedFont);
}
const QUrl url(u"qrc:/gpt4all/main.qml"_s); const QUrl url(u"qrc:/gpt4all/main.qml"_s);
QObject::connect(&engine, &QQmlApplicationEngine::objectCreated, QObject::connect(&engine, &QQmlApplicationEngine::objectCreated,

View File

@ -316,26 +316,44 @@ void ModelInfo::setRepeatPenaltyTokens(int t)
m_repeatPenaltyTokens = t; m_repeatPenaltyTokens = t;
} }
QString ModelInfo::promptTemplate() const QVariant ModelInfo::defaultChatTemplate() const
{ {
return MySettings::globalInstance()->modelPromptTemplate(*this); auto res = m_chatTemplate.or_else([this] -> std::optional<QString> {
if (!installed || isOnline)
return std::nullopt;
if (!m_modelChatTemplate) {
auto path = (dirpath + filename()).toUtf8();
auto res = LLModel::Implementation::chatTemplate(path.constData());
if (res) {
m_modelChatTemplate = QString::fromStdString(*res);
} else {
qWarning().nospace() << "failed to get chat template for " << filename() << ": " << res.error().c_str();
m_modelChatTemplate = QString(); // do not retry
}
}
if (m_modelChatTemplate->isNull())
return std::nullopt;
return m_modelChatTemplate;
});
if (res)
return std::move(*res);
return QVariant::fromValue(nullptr);
} }
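
defaultChatTemplate() now lazily pulls the template shipped inside the model file (via LLModel::Implementation::chatTemplate, which returns a std::expected) and caches the result, including failures, so the model file is read at most once per ModelInfo. A stripped-down sketch of that caching pattern; the types are stand-ins, and an empty cached string here plays the role of the real code's null QString:

#include <expected>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for LLModel::Implementation::chatTemplate(): read the template
// embedded in the model file, or return an error message. Stubbed out here.
std::expected<std::string, std::string> readTemplateFromModelFile(const std::string &path)
{
    return std::unexpected("no chat template found in " + path);
}

struct ModelEntry {
    std::string path;
    std::optional<std::string> userTemplate;            // explicit override, if any
    mutable std::optional<std::string> cachedTemplate;  // empty string marks "tried and failed"

    std::optional<std::string> defaultChatTemplate() const
    {
        if (userTemplate)
            return userTemplate;
        if (!cachedTemplate) {                           // first access: read the file once
            if (auto res = readTemplateFromModelFile(path)) {
                cachedTemplate = *res;
            } else {
                std::cerr << "failed to get chat template for " << path
                          << ": " << res.error() << '\n';
                cachedTemplate = std::string();          // cache the failure; do not retry
            }
        }
        if (cachedTemplate->empty())
            return std::nullopt;
        return cachedTemplate;
    }
};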
void ModelInfo::setPromptTemplate(const QString &t) auto ModelInfo::chatTemplate() const -> UpgradeableSetting
{ {
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelPromptTemplate(*this, t, true /*force*/); return MySettings::globalInstance()->modelChatTemplate(*this);
m_promptTemplate = t;
} }
QString ModelInfo::systemPrompt() const QString ModelInfo::defaultSystemMessage() const
{ {
return MySettings::globalInstance()->modelSystemPrompt(*this); return m_systemMessage;
} }
void ModelInfo::setSystemPrompt(const QString &p) auto ModelInfo::systemMessage() const -> UpgradeableSetting
{ {
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/); return MySettings::globalInstance()->modelSystemMessage(*this);
m_systemPrompt = p;
} }
QString ModelInfo::chatNamePrompt() const QString ModelInfo::chatNamePrompt() const
@ -360,39 +378,41 @@ void ModelInfo::setSuggestedFollowUpPrompt(const QString &p)
m_suggestedFollowUpPrompt = p; m_suggestedFollowUpPrompt = p;
} }
// FIXME(jared): this should not be used for model settings that have meaningful defaults, such as temperature
bool ModelInfo::shouldSaveMetadata() const bool ModelInfo::shouldSaveMetadata() const
{ {
return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/); return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/);
} }
QVariantMap ModelInfo::getFields() const QVariant ModelInfo::getField(QLatin1StringView name) const
{ {
return { static const std::unordered_map<QLatin1StringView, QVariant(*)(const ModelInfo &)> s_fields = {
{ "filename", m_filename }, { "filename"_L1, [](auto &i) -> QVariant { return i.m_filename; } },
{ "description", m_description }, { "description"_L1, [](auto &i) -> QVariant { return i.m_description; } },
{ "url", m_url }, { "url"_L1, [](auto &i) -> QVariant { return i.m_url; } },
{ "quant", m_quant }, { "quant"_L1, [](auto &i) -> QVariant { return i.m_quant; } },
{ "type", m_type }, { "type"_L1, [](auto &i) -> QVariant { return i.m_type; } },
{ "isClone", m_isClone }, { "isClone"_L1, [](auto &i) -> QVariant { return i.m_isClone; } },
{ "isDiscovered", m_isDiscovered }, { "isDiscovered"_L1, [](auto &i) -> QVariant { return i.m_isDiscovered; } },
{ "likes", m_likes }, { "likes"_L1, [](auto &i) -> QVariant { return i.m_likes; } },
{ "downloads", m_downloads }, { "downloads"_L1, [](auto &i) -> QVariant { return i.m_downloads; } },
{ "recency", m_recency }, { "recency"_L1, [](auto &i) -> QVariant { return i.m_recency; } },
{ "temperature", m_temperature }, { "temperature"_L1, [](auto &i) -> QVariant { return i.m_temperature; } },
{ "topP", m_topP }, { "topP"_L1, [](auto &i) -> QVariant { return i.m_topP; } },
{ "minP", m_minP }, { "minP"_L1, [](auto &i) -> QVariant { return i.m_minP; } },
{ "topK", m_topK }, { "topK"_L1, [](auto &i) -> QVariant { return i.m_topK; } },
{ "maxLength", m_maxLength }, { "maxLength"_L1, [](auto &i) -> QVariant { return i.m_maxLength; } },
{ "promptBatchSize", m_promptBatchSize }, { "promptBatchSize"_L1, [](auto &i) -> QVariant { return i.m_promptBatchSize; } },
{ "contextLength", m_contextLength }, { "contextLength"_L1, [](auto &i) -> QVariant { return i.m_contextLength; } },
{ "gpuLayers", m_gpuLayers }, { "gpuLayers"_L1, [](auto &i) -> QVariant { return i.m_gpuLayers; } },
{ "repeatPenalty", m_repeatPenalty }, { "repeatPenalty"_L1, [](auto &i) -> QVariant { return i.m_repeatPenalty; } },
{ "repeatPenaltyTokens", m_repeatPenaltyTokens }, { "repeatPenaltyTokens"_L1, [](auto &i) -> QVariant { return i.m_repeatPenaltyTokens; } },
{ "promptTemplate", m_promptTemplate }, { "chatTemplate"_L1, [](auto &i) -> QVariant { return i.defaultChatTemplate(); } },
{ "systemPrompt", m_systemPrompt }, { "systemMessage"_L1, [](auto &i) -> QVariant { return i.m_systemMessage; } },
{ "chatNamePrompt", m_chatNamePrompt }, { "chatNamePrompt"_L1, [](auto &i) -> QVariant { return i.m_chatNamePrompt; } },
{ "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt }, { "suggestedFollowUpPrompt"_L1, [](auto &i) -> QVariant { return i.m_suggestedFollowUpPrompt; } },
}; };
return s_fields.at(name)(*this);
} }
InstalledModels::InstalledModels(QObject *parent, bool selectable) InstalledModels::InstalledModels(QObject *parent, bool selectable)
@ -491,31 +511,48 @@ ModelList::ModelList()
m_selectableModels->setSourceModel(this); m_selectableModels->setSourceModel(this);
m_downloadableModels->setSourceModel(this); m_downloadableModels->setSourceModel(this);
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory); auto *mySettings = MySettings::globalInstance();
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson); connect(mySettings, &MySettings::nameChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings); connect(mySettings, &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::topPChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::minPChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::topKChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings );
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::chatTemplateChanged, this, &ModelList::maybeUpdateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings); connect(mySettings, &MySettings::systemMessageChanged, this, &ModelList::maybeUpdateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings); connect(this, &ModelList::dataChanged, this, &ModelList::onDataChanged);
connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors); connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors);
updateModelsFromJson(); updateModelsFromJson();
updateModelsFromSettings(); updateModelsFromSettings();
updateModelsFromDirectory(); updateModelsFromDirectory();
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson );
connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings );
QCoreApplication::instance()->installEventFilter(this); QCoreApplication::instance()->installEventFilter(this);
} }
// an easier way to listen for model info and setting changes
void ModelList::onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
{
Q_UNUSED(roles)
for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
auto index = topLeft.siblingAtRow(row);
auto id = index.data(ModelList::IdRole).toString();
if (auto info = modelInfo(id); !info.id().isNull())
emit modelInfoChanged(info);
}
}
QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) { QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) {
QCryptographicHash sha256(QCryptographicHash::Sha256); QCryptographicHash sha256(QCryptographicHash::Sha256);
sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8()); sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8());
@ -776,10 +813,10 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->repeatPenalty(); return info->repeatPenalty();
case RepeatPenaltyTokensRole: case RepeatPenaltyTokensRole:
return info->repeatPenaltyTokens(); return info->repeatPenaltyTokens();
case PromptTemplateRole: case ChatTemplateRole:
return info->promptTemplate(); return QVariant::fromValue(info->chatTemplate());
case SystemPromptRole: case SystemMessageRole:
return info->systemPrompt(); return QVariant::fromValue(info->systemMessage());
case ChatNamePromptRole: case ChatNamePromptRole:
return info->chatNamePrompt(); return info->chatNamePrompt();
case SuggestedFollowUpPromptRole: case SuggestedFollowUpPromptRole:
@ -952,10 +989,10 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
info->setRepeatPenalty(value.toDouble()); break; info->setRepeatPenalty(value.toDouble()); break;
case RepeatPenaltyTokensRole: case RepeatPenaltyTokensRole:
info->setRepeatPenaltyTokens(value.toInt()); break; info->setRepeatPenaltyTokens(value.toInt()); break;
case PromptTemplateRole: case ChatTemplateRole:
info->setPromptTemplate(value.toString()); break; info->m_chatTemplate = value.toString(); break;
case SystemPromptRole: case SystemMessageRole:
info->setSystemPrompt(value.toString()); break; info->m_systemMessage = value.toString(); break;
case ChatNamePromptRole: case ChatNamePromptRole:
info->setChatNamePrompt(value.toString()); break; info->setChatNamePrompt(value.toString()); break;
case SuggestedFollowUpPromptRole: case SuggestedFollowUpPromptRole:
@ -1056,11 +1093,11 @@ ModelInfo ModelList::modelInfo(const QString &id) const
return *m_modelMap.value(id); return *m_modelMap.value(id);
} }
ModelInfo ModelList::modelInfoByFilename(const QString &filename) const ModelInfo ModelList::modelInfoByFilename(const QString &filename, bool allowClone) const
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
for (ModelInfo *info : m_models) for (ModelInfo *info : m_models)
if (info->filename() == filename) if (info->filename() == filename && (allowClone || !info->isClone()))
return *info; return *info;
return ModelInfo(); return ModelInfo();
} }
@ -1080,6 +1117,20 @@ QString ModelList::clone(const ModelInfo &model)
const QString id = Network::globalInstance()->generateUniqueId(); const QString id = Network::globalInstance()->generateUniqueId();
addModel(id); addModel(id);
QString chatTemplate, systemMessage;
if (auto tmpl = model.chatTemplate().asModern()) {
chatTemplate = *tmpl;
} else {
qWarning("ModelList Warning: attempted to clone model with legacy chat template");
return {};
}
if (auto msg = model.systemMessage().asModern()) {
systemMessage = *msg;
} else {
qWarning("ModelList Warning: attempted to clone model with legacy system message");
return {};
}
QVector<QPair<int, QVariant>> data { QVector<QPair<int, QVariant>> data {
{ ModelList::InstalledRole, model.installed }, { ModelList::InstalledRole, model.installed },
{ ModelList::IsCloneRole, true }, { ModelList::IsCloneRole, true },
@ -1099,8 +1150,8 @@ QString ModelList::clone(const ModelInfo &model)
{ ModelList::GpuLayersRole, model.gpuLayers() }, { ModelList::GpuLayersRole, model.gpuLayers() },
{ ModelList::RepeatPenaltyRole, model.repeatPenalty() }, { ModelList::RepeatPenaltyRole, model.repeatPenalty() },
{ ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() }, { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
{ ModelList::PromptTemplateRole, model.promptTemplate() }, { ModelList::ChatTemplateRole, chatTemplate },
{ ModelList::SystemPromptRole, model.systemPrompt() }, { ModelList::SystemMessageRole, systemMessage },
{ ModelList::ChatNamePromptRole, model.chatNamePrompt() }, { ModelList::ChatNamePromptRole, model.chatNamePrompt() },
{ ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() }, { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
}; };
@ -1125,21 +1176,23 @@ void ModelList::removeInstalled(const ModelInfo &model)
removeInternal(model); removeInternal(model);
} }
int ModelList::indexByModelId(const QString &id) const
{
QMutexLocker locker(&m_mutex);
if (auto it = m_modelMap.find(id); it != m_modelMap.cend())
return m_models.indexOf(*it);
return -1;
}
void ModelList::removeInternal(const ModelInfo &model) void ModelList::removeInternal(const ModelInfo &model)
{ {
const bool hasModel = contains(model.id()); int indexOfModel = indexByModelId(model.id());
Q_ASSERT(hasModel); Q_ASSERT(indexOfModel != -1);
if (!hasModel) { if (indexOfModel == -1) {
qWarning() << "ERROR: model list does not contain" << model.id(); qWarning() << "ERROR: model list does not contain" << model.id();
return; return;
} }
int indexOfModel = 0;
{
QMutexLocker locker(&m_mutex);
ModelInfo *info = m_modelMap.value(model.id());
indexOfModel = m_models.indexOf(info);
}
beginRemoveRows(QModelIndex(), indexOfModel, indexOfModel); beginRemoveRows(QModelIndex(), indexOfModel, indexOfModel);
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
@ -1314,8 +1367,6 @@ void ModelList::processModelDirectory(const QString &path)
// The description is hard-coded into "GPT4All.ini" due to performance issue. // The description is hard-coded into "GPT4All.ini" due to performance issue.
// If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList. // If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList.
data.append({ DescriptionRole, description }); data.append({ DescriptionRole, description });
// Prompt template should be clear while using ChatML format which is using in most of OpenAI-Compatible API server.
data.append({ PromptTemplateRole, "%1" });
} }
updateData(id, data); updateData(id, data);
} }
@ -1451,9 +1502,20 @@ void ModelList::handleSslErrors(QNetworkReply *reply, const QList<QSslError> &er
qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url; qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url;
} }
void ModelList::maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo)
{
// ignore updates that were *because* of a dataChanged - would cause a circular dependency
int idx;
if (!fromInfo && (idx = indexByModelId(info.id())) != -1) {
emit dataChanged(index(idx, 0), index(idx, 0));
emit selectableModelListChanged();
}
}
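
maybeUpdateDataForSettings() exists because settings changes and model-list changes can now trigger each other: MySettings listens to ModelList::dataChanged to re-emit per-model setting signals, and ModelList listens to those setting signals to refresh rows. The fromInfo flag breaks the cycle by ignoring notifications that originated from the model list itself. A toy sketch of that guard with plain callbacks instead of Qt signals:

#include <functional>
#include <iostream>

// Two components that notify each other; the bool flag marks whether a
// notification originated from the model side, so it is not echoed back.
struct Settings {
    std::function<void(bool cameFromModel)> notifyModels;
    void onModelDefaultChanged() { if (notifyModels) notifyModels(/*cameFromModel*/ true); }
    void onUserEditedSetting()   { if (notifyModels) notifyModels(/*cameFromModel*/ false); }
};

struct Models {
    Settings *settings = nullptr;
    void onSettingChanged(bool cameFromModel)
    {
        if (cameFromModel)
            return; // would otherwise re-emit dataChanged and loop forever
        std::cout << "refreshing model rows\n";
        if (settings) settings->onModelDefaultChanged(); // dataChanged -> MySettings
    }
};

int main()
{
    Settings settings;
    Models models;
    models.settings = &settings;
    settings.notifyModels = [&](bool fromModel) { models.onSettingChanged(fromModel); };

    settings.onUserEditedSetting(); // rows refresh once, then the echo is ignored
}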
void ModelList::updateDataForSettings() void ModelList::updateDataForSettings()
{ {
emit dataChanged(index(0, 0), index(m_models.size() - 1, 0)); emit dataChanged(index(0, 0), index(m_models.size() - 1, 0));
emit selectableModelListChanged();
} }
void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
@ -1560,10 +1622,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
data.append({ ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble() }); data.append({ ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble() });
if (obj.contains("repeatPenaltyTokens")) if (obj.contains("repeatPenaltyTokens"))
data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() }); data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() });
if (obj.contains("promptTemplate")) if (auto it = obj.find("chatTemplate"_L1); it != obj.end())
data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() }); data.append({ ModelList::ChatTemplateRole, it->toString() });
if (obj.contains("systemPrompt")) if (auto it = obj.find("systemMessage"_L1); it != obj.end())
data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() }); data.append({ ModelList::SystemMessageRole, it->toString() });
updateData(id, data); updateData(id, data);
} }
@ -1755,6 +1817,9 @@ void ModelList::updateDiscoveredInstalled(const ModelInfo &info)
updateData(info.id(), data); updateData(info.id(), data);
} }
// FIXME(jared): This should only contain fields without reasonable defaults such as name, description, and URL.
// For other settings, there is no authoritative value and we should load the setting lazily like we do
// for any other override.
void ModelList::updateModelsFromSettings() void ModelList::updateModelsFromSettings()
{ {
QSettings settings; QSettings settings;
@ -1769,12 +1834,27 @@ void ModelList::updateModelsFromSettings()
// If we can't find the corresponding file, then ignore it as this reflects a stale model. // If we can't find the corresponding file, then ignore it as this reflects a stale model.
// The file could have been deleted manually by the user for instance or temporarily renamed. // The file could have been deleted manually by the user for instance or temporarily renamed.
if (!settings.contains(g + "/filename") || !modelExists(settings.value(g + "/filename").toString())) QString filename;
continue; {
auto value = settings.value(u"%1/filename"_s.arg(g));
if (!value.isValid() || !modelExists(filename = value.toString()))
continue;
}
QVector<QPair<int, QVariant>> data;
// load data from base model
// FIXME(jared): how does "Restore Defaults" work for other settings of clones which we don't do this for?
if (auto base = modelInfoByFilename(filename, /*allowClone*/ false); !base.id().isNull()) {
if (auto tmpl = base.m_chatTemplate)
data.append({ ModelList::ChatTemplateRole, *tmpl });
if (auto msg = base.m_systemMessage; !msg.isNull())
data.append({ ModelList::SystemMessageRole, msg });
}
addModel(id); addModel(id);
QVector<QPair<int, QVariant>> data; // load data from settings
if (settings.contains(g + "/name")) { if (settings.contains(g + "/name")) {
const QString name = settings.value(g + "/name").toString(); const QString name = settings.value(g + "/name").toString();
data.append({ ModelList::NameRole, name }); data.append({ ModelList::NameRole, name });
@ -1859,14 +1939,6 @@ void ModelList::updateModelsFromSettings()
const int repeatPenaltyTokens = settings.value(g + "/repeatPenaltyTokens").toInt(); const int repeatPenaltyTokens = settings.value(g + "/repeatPenaltyTokens").toInt();
data.append({ ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens }); data.append({ ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens });
} }
if (settings.contains(g + "/promptTemplate")) {
const QString promptTemplate = settings.value(g + "/promptTemplate").toString();
data.append({ ModelList::PromptTemplateRole, promptTemplate });
}
if (settings.contains(g + "/systemPrompt")) {
const QString systemPrompt = settings.value(g + "/systemPrompt").toString();
data.append({ ModelList::SystemPromptRole, systemPrompt });
}
if (settings.contains(g + "/chatNamePrompt")) { if (settings.contains(g + "/chatNamePrompt")) {
const QString chatNamePrompt = settings.value(g + "/chatNamePrompt").toString(); const QString chatNamePrompt = settings.value(g + "/chatNamePrompt").toString();
data.append({ ModelList::ChatNamePromptRole, chatNamePrompt }); data.append({ ModelList::ChatNamePromptRole, chatNamePrompt });

View File

@ -5,12 +5,14 @@
#include <QByteArray> #include <QByteArray>
#include <QDateTime> #include <QDateTime>
#include <QHash> #include <QHash>
#include <QLatin1StringView>
#include <QList> #include <QList>
#include <QMutex> #include <QMutex>
#include <QNetworkAccessManager> #include <QNetworkAccessManager>
#include <QNetworkReply> #include <QNetworkReply>
#include <QObject> #include <QObject>
#include <QPair> #include <QPair>
#include <QQmlEngine>
#include <QSortFilterProxyModel> #include <QSortFilterProxyModel>
#include <QSslError> #include <QSslError>
#include <QString> #include <QString>
@ -19,11 +21,53 @@
#include <Qt> #include <Qt>
#include <QtGlobal> #include <QtGlobal>
#include <optional>
#include <utility> #include <utility>
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
class UpgradeableSetting {
Q_GADGET
QML_ANONYMOUS
// NOTE: Unset implies there is neither a value nor a default
enum class State { Unset, Legacy, Modern };
Q_PROPERTY(bool isSet READ isSet )
Q_PROPERTY(bool isLegacy READ isLegacy)
Q_PROPERTY(bool isModern READ isModern)
Q_PROPERTY(QVariant value READ value) // string or null
public:
struct legacy_tag_t { explicit legacy_tag_t() = default; };
static inline constexpr legacy_tag_t legacy_tag = legacy_tag_t();
UpgradeableSetting() : m_state(State::Unset ) {}
UpgradeableSetting(legacy_tag_t, QString value): m_state(State::Legacy), m_value(std::move(value)) {}
UpgradeableSetting( QString value): m_state(State::Modern), m_value(std::move(value)) {}
bool isSet () const { return m_state != State::Unset; }
bool isLegacy() const { return m_state == State::Legacy; }
bool isModern() const { return m_state == State::Modern; }
QVariant value () const { return m_state == State::Unset ? QVariant::fromValue(nullptr) : m_value; }
friend bool operator==(const UpgradeableSetting &a, const UpgradeableSetting &b)
{ return a.m_state == b.m_state && (a.m_state == State::Unset || a.m_value == b.m_value); }
// returns std::nullopt if there is a legacy template or it is not set
std::optional<QString> asModern() const
{
if (m_state == State::Modern)
return m_value;
return std::nullopt;
}
private:
State m_state;
QString m_value;
};
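
A short usage sketch for UpgradeableSetting, assuming the class declared above is in scope (e.g. via this header): callers that can only handle the new Jinja-style values go through asModern(), while callers can branch on isLegacy() to ask the user to upgrade. The consumer function here is hypothetical.

#include <QDebug>
#include <QString>
#include <optional>

// Hypothetical consumer: returns the template to hand to the Jinja engine,
// or std::nullopt if the model still carries a legacy %1-style template.
std::optional<QString> usableTemplate(const UpgradeableSetting &setting)
{
    if (auto tmpl = setting.asModern())
        return tmpl;   // Modern: safe to render
    if (setting.isLegacy())
        qWarning() << "legacy prompt template present; ask the user to reset or rewrite it";
    return std::nullopt; // Legacy or Unset
}

// Construction mirrors the three states:
//   UpgradeableSetting{}                                     -> Unset
//   UpgradeableSetting(UpgradeableSetting::legacy_tag, "%1") -> Legacy
//   UpgradeableSetting(QString("..."))                       -> Modern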
struct ModelInfo { struct ModelInfo {
Q_GADGET Q_GADGET
Q_PROPERTY(QString id READ id WRITE setId) Q_PROPERTY(QString id READ id WRITE setId)
@ -69,8 +113,11 @@ struct ModelInfo {
Q_PROPERTY(int maxGpuLayers READ maxGpuLayers) Q_PROPERTY(int maxGpuLayers READ maxGpuLayers)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty) Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) // user-defined chat template and system message must be written through settings because of their legacy compat
Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt) Q_PROPERTY(QVariant defaultChatTemplate READ defaultChatTemplate )
Q_PROPERTY(UpgradeableSetting chatTemplate READ chatTemplate )
Q_PROPERTY(QString defaultSystemMessage READ defaultSystemMessage)
Q_PROPERTY(UpgradeableSetting systemMessage READ systemMessage )
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt) Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt) Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
Q_PROPERTY(int likes READ likes WRITE setLikes) Q_PROPERTY(int likes READ likes WRITE setLikes)
@ -178,19 +225,22 @@ public:
void setRepeatPenalty(double p); void setRepeatPenalty(double p);
int repeatPenaltyTokens() const; int repeatPenaltyTokens() const;
void setRepeatPenaltyTokens(int t); void setRepeatPenaltyTokens(int t);
QString promptTemplate() const; QVariant defaultChatTemplate() const;
void setPromptTemplate(const QString &t); UpgradeableSetting chatTemplate() const;
QString systemPrompt() const; QString defaultSystemMessage() const;
void setSystemPrompt(const QString &p); UpgradeableSetting systemMessage() const;
QString chatNamePrompt() const; QString chatNamePrompt() const;
void setChatNamePrompt(const QString &p); void setChatNamePrompt(const QString &p);
QString suggestedFollowUpPrompt() const; QString suggestedFollowUpPrompt() const;
void setSuggestedFollowUpPrompt(const QString &p); void setSuggestedFollowUpPrompt(const QString &p);
// Some metadata must be saved to settings because it does not have a meaningful default from some other source.
// This is useful for fields such as name, description, and URL.
// It is true for any models that have not been installed from models.json.
bool shouldSaveMetadata() const; bool shouldSaveMetadata() const;
private: private:
QVariantMap getFields() const; QVariant getField(QLatin1StringView name) const;
QString m_id; QString m_id;
QString m_name; QString m_name;
@ -216,11 +266,13 @@ private:
mutable int m_maxGpuLayers = -1; mutable int m_maxGpuLayers = -1;
double m_repeatPenalty = 1.18; double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64; int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n"; std::optional<QString> m_chatTemplate;
QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n"; mutable std::optional<QString> m_modelChatTemplate;
QString m_systemMessage;
QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
friend class MySettings; friend class MySettings;
friend class ModelList;
}; };
Q_DECLARE_METATYPE(ModelInfo) Q_DECLARE_METATYPE(ModelInfo)
@ -340,8 +392,8 @@ public:
GpuLayersRole, GpuLayersRole,
RepeatPenaltyRole, RepeatPenaltyRole,
RepeatPenaltyTokensRole, RepeatPenaltyTokensRole,
PromptTemplateRole, ChatTemplateRole,
SystemPromptRole, SystemMessageRole,
ChatNamePromptRole, ChatNamePromptRole,
SuggestedFollowUpPromptRole, SuggestedFollowUpPromptRole,
MinPRole, MinPRole,
@ -394,8 +446,8 @@ public:
roles[GpuLayersRole] = "gpuLayers"; roles[GpuLayersRole] = "gpuLayers";
roles[RepeatPenaltyRole] = "repeatPenalty"; roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate"; roles[ChatTemplateRole] = "chatTemplate";
roles[SystemPromptRole] = "systemPrompt"; roles[SystemMessageRole] = "systemMessage";
roles[ChatNamePromptRole] = "chatNamePrompt"; roles[ChatNamePromptRole] = "chatNamePrompt";
roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt"; roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";
roles[LikesRole] = "likes"; roles[LikesRole] = "likes";
@ -416,7 +468,7 @@ public:
bool contains(const QString &id) const; bool contains(const QString &id) const;
bool containsByFilename(const QString &filename) const; bool containsByFilename(const QString &filename) const;
Q_INVOKABLE ModelInfo modelInfo(const QString &id) const; Q_INVOKABLE ModelInfo modelInfo(const QString &id) const;
Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename) const; Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename, bool allowClone = true) const;
Q_INVOKABLE bool isUniqueName(const QString &name) const; Q_INVOKABLE bool isUniqueName(const QString &name) const;
Q_INVOKABLE QString clone(const ModelInfo &model); Q_INVOKABLE QString clone(const ModelInfo &model);
Q_INVOKABLE void removeClone(const ModelInfo &model); Q_INVOKABLE void removeClone(const ModelInfo &model);
@ -476,15 +528,18 @@ Q_SIGNALS:
void discoverSortChanged(); void discoverSortChanged();
void discoverProgressChanged(); void discoverProgressChanged();
void discoverInProgressChanged(); void discoverInProgressChanged();
void modelInfoChanged(const ModelInfo &info);
protected: protected:
bool eventFilter(QObject *obj, QEvent *ev) override; bool eventFilter(QObject *obj, QEvent *ev) override;
private Q_SLOTS: private Q_SLOTS:
void onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles);
void resortModel(); void resortModel();
void updateModelsFromJson(); void updateModelsFromJson();
void updateModelsFromJsonAsync(); void updateModelsFromJsonAsync();
void updateModelsFromSettings(); void updateModelsFromSettings();
void maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo);
void updateDataForSettings(); void updateDataForSettings();
void handleModelsJsonDownloadFinished(); void handleModelsJsonDownloadFinished();
void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code); void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code);
@ -495,6 +550,9 @@ private Q_SLOTS:
void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors); void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors);
private: private:
// Return the index of the model with the given id, or -1 if not found.
int indexByModelId(const QString &id) const;
void removeInternal(const ModelInfo &model); void removeInternal(const ModelInfo &model);
void clearDiscoveredModels(); void clearDiscoveredModels();
bool modelExists(const QString &fileName) const; bool modelExists(const QString &fileName) const;

View File

@ -1,5 +1,8 @@
#include "mysettings.h" #include "mysettings.h"
#include "chatllm.h"
#include "modellist.h"
#include <gpt4all-backend/llmodel.h> #include <gpt4all-backend/llmodel.h>
#include <QDebug> #include <QDebug>
@ -29,8 +32,13 @@ static const QStringList suggestionModeNames { "LocalDocsOnly", "On", "Off" };
static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" }; static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" };
static const QStringList fontSizeNames { "Small", "Medium", "Large" }; static const QStringList fontSizeNames { "Small", "Medium", "Large" };
// FIXME: All of these default strings that are shown in the UI for settings need to be marked as // pseudo-enum

// translatable namespace ModelSettingsKey { namespace {
auto ChatTemplate = "chatTemplate"_L1;
auto PromptTemplate = "promptTemplate"_L1; // legacy
auto SystemMessage = "systemMessage"_L1;
auto SystemPrompt = "systemPrompt"_L1; // legacy
} } // namespace ModelSettingsKey::(anonymous)
namespace defaults { namespace defaults {
@ -48,7 +56,6 @@ static const QVariantMap basicDefaults {
{ "fontSize", QVariant::fromValue(FontSize::Small) }, { "fontSize", QVariant::fromValue(FontSize::Small) },
{ "lastVersionStarted", "" }, { "lastVersionStarted", "" },
{ "networkPort", 4891, }, { "networkPort", 4891, },
{ "saveChatsContext", false },
{ "systemTray", false }, { "systemTray", false },
{ "serverChat", false }, { "serverChat", false },
{ "userDefaultModel", "Application default" }, { "userDefaultModel", "Application default" },
@ -147,6 +154,11 @@ static QStringList getUiLanguages(const QString &modelPath)
return languageList; return languageList;
} }
static QString modelSettingName(const ModelInfo &info, auto &&name)
{
return u"model-%1/%2"_s.arg(info.id(), name);
}
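
modelSettingName() gives every per-model override a key of the form model-<id>/<setting> in QSettings, which is what the legacy/modern lookups below rely on. A small sketch of what that layout looks like; the organization/application names, model id and values here are invented placeholders:

#include <QCoreApplication>
#include <QDebug>
#include <QSettings>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);
    QCoreApplication::setOrganizationName("example");        // placeholder
    QCoreApplication::setApplicationName("settings-sketch"); // placeholder

    QSettings settings;
    // Per-model overrides live under "model-<id>/<key>".
    settings.setValue("model-1234/chatTemplate",  "{# modern Jinja template #}");
    settings.setValue("model-1234/systemMessage", "You are a helpful assistant.");

    // Older releases may have left the legacy keys behind:
    //   model-1234/promptTemplate, model-1234/systemPrompt
    qDebug() << settings.value("model-1234/chatTemplate").toString();
}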
class MyPrivateSettings: public MySettings { }; class MyPrivateSettings: public MySettings { };
Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance) Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance)
MySettings *MySettings::globalInstance() MySettings *MySettings::globalInstance()
@ -162,6 +174,34 @@ MySettings::MySettings()
{ {
} }
QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
{
if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
return QString::fromStdString(*err);
return QVariant::fromValue(nullptr);
}
// Unset settings come from ModelInfo. Listen for changes so we can emit our own setting-specific signals.
void MySettings::onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
{
auto settingChanged = [&](const auto &info, auto role, const auto &name) {
return (roles.isEmpty() || roles.contains(role)) && !m_settings.contains(modelSettingName(info, name));
};
auto &modelList = dynamic_cast<const ModelList &>(*QObject::sender());
for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
using enum ModelList::Roles;
using namespace ModelSettingsKey;
auto index = topLeft.siblingAtRow(row);
if (auto info = modelList.modelInfo(index.data(IdRole).toString()); !info.id().isNull()) {
if (settingChanged(info, ChatTemplateRole, ChatTemplate))
emit chatTemplateChanged(info, /*fromInfo*/ true);
if (settingChanged(info, SystemMessageRole, SystemMessage))
emit systemMessageChanged(info, /*fromInfo*/ true);
}
}
}
QVariant MySettings::getBasicSetting(const QString &name) const QVariant MySettings::getBasicSetting(const QString &name) const
{ {
return m_settings.value(name, basicDefaults.value(name)); return m_settings.value(name, basicDefaults.value(name));
@ -194,8 +234,8 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
setModelGpuLayers(info, info.m_gpuLayers); setModelGpuLayers(info, info.m_gpuLayers);
setModelRepeatPenalty(info, info.m_repeatPenalty); setModelRepeatPenalty(info, info.m_repeatPenalty);
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens); setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
setModelPromptTemplate(info, info.m_promptTemplate); resetModelChatTemplate (info);
setModelSystemPrompt(info, info.m_systemPrompt); resetModelSystemMessage(info);
setModelChatNamePrompt(info, info.m_chatNamePrompt); setModelChatNamePrompt(info, info.m_chatNamePrompt);
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt); setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
} }
@ -206,7 +246,6 @@ void MySettings::restoreApplicationDefaults()
setFontSize(basicDefaults.value("fontSize").value<FontSize>()); setFontSize(basicDefaults.value("fontSize").value<FontSize>());
setDevice(defaults::device); setDevice(defaults::device);
setThreadCount(defaults::threadCount); setThreadCount(defaults::threadCount);
setSaveChatsContext(basicDefaults.value("saveChatsContext").toBool());
setSystemTray(basicDefaults.value("systemTray").toBool()); setSystemTray(basicDefaults.value("systemTray").toBool());
setServerChat(basicDefaults.value("serverChat").toBool()); setServerChat(basicDefaults.value("serverChat").toBool());
setNetworkPort(basicDefaults.value("networkPort").toInt()); setNetworkPort(basicDefaults.value("networkPort").toInt());
@ -252,29 +291,37 @@ void MySettings::setModelName(const ModelInfo &info, const QString &value, bool
emit nameChanged(info); emit nameChanged(info);
} }
static QString modelSettingName(const ModelInfo &info, const QString &name) QVariant MySettings::getModelSetting(QLatin1StringView name, const ModelInfo &info) const
{ {
return u"model-%1/%2"_s.arg(info.id(), name); QLatin1StringView nameL1(name);
return m_settings.value(modelSettingName(info, nameL1), info.getField(nameL1));
} }
QVariant MySettings::getModelSetting(const QString &name, const ModelInfo &info) const QVariant MySettings::getModelSetting(const char *name, const ModelInfo &info) const
{ {
return m_settings.value(modelSettingName(info, name), info.getFields().value(name)); return getModelSetting(QLatin1StringView(name), info);
} }
void MySettings::setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force, void MySettings::setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
bool signal) bool signal)
{ {
if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value)) if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value))
return; return;
QString settingName = modelSettingName(info, name); QLatin1StringView nameL1(name);
if (info.getFields().value(name) == value && !info.shouldSaveMetadata()) QString settingName = modelSettingName(info, nameL1);
if (info.getField(nameL1) == value && !info.shouldSaveMetadata())
m_settings.remove(settingName); m_settings.remove(settingName);
else else
m_settings.setValue(settingName, value); m_settings.setValue(settingName, value);
if (signal && !force) if (signal && !force)
QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(name).toLatin1().constData(), Q_ARG(ModelInfo, info)); QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(nameL1).toLatin1().constData(), Q_ARG(ModelInfo, info));
}
void MySettings::setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
bool signal)
{
setModelSetting(QLatin1StringView(name), info, value, force, signal);
} }
QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); } QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); }
@ -297,11 +344,68 @@ int MySettings::modelContextLength (const ModelInfo &info) const
int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); } int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); }
double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); } double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); } int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); } QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); } QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
auto MySettings::getUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const -> UpgradeableSetting
{
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return {};
}
auto value = m_settings.value(modelSettingName(info, legacyKey));
if (value.isValid())
return { UpgradeableSetting::legacy_tag, value.toString() };
value = getModelSetting(newKey, info);
if (!value.isNull())
return value.toString();
return {}; // neither a default nor an override
}
bool MySettings::isUpgradeableModelSettingSet(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const
{
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
if (m_settings.contains(modelSettingName(info, legacyKey)))
return true;
    // NOTE: unlike getUpgradeableModelSetting(), this ignores the default
return m_settings.contains(modelSettingName(info, newKey));
}
auto MySettings::modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting
{
using namespace ModelSettingsKey;
return getUpgradeableModelSetting(info, PromptTemplate, ChatTemplate);
}
bool MySettings::isModelChatTemplateSet(const ModelInfo &info) const
{
using namespace ModelSettingsKey;
return isUpgradeableModelSettingSet(info, PromptTemplate, ChatTemplate);
}
auto MySettings::modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting
{
using namespace ModelSettingsKey;
return getUpgradeableModelSetting(info, SystemPrompt, SystemMessage);
}
bool MySettings::isModelSystemMessageSet(const ModelInfo &info) const
{
using namespace ModelSettingsKey;
return isUpgradeableModelSettingSet(info, SystemPrompt, SystemMessage);
}
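The lookup order implemented by getUpgradeableModelSetting() is: an explicit value stored under the legacy key wins and is reported as a legacy setting; otherwise the new key (including its default) is consulted; otherwise the setting is treated as unset. A minimal standalone sketch of that order, using plain QSettings and illustrative key names rather than the real ModelInfo plumbing:

```cpp
// Sketch only: mirrors the legacy-key-first lookup of getUpgradeableModelSetting().
// The Result type and key names are illustrative, not the gpt4all-chat types.
#include <QDebug>
#include <QSettings>
#include <QString>
#include <QVariant>
#include <optional>

struct Result {
    QString value;
    bool    isLegacy = false;
};

std::optional<Result> lookupUpgradeable(QSettings &settings, const QString &legacyKey,
                                        const QString &newKey, const QVariant &builtinDefault)
{
    // 1. A value stored under the legacy key takes priority and is flagged,
    //    so the UI can offer to upgrade it.
    QVariant legacy = settings.value(legacyKey);
    if (legacy.isValid())
        return Result { legacy.toString(), /*isLegacy*/ true };

    // 2. Otherwise fall back to the new key, honoring its built-in default.
    QVariant modern = settings.value(newKey, builtinDefault);
    if (!modern.isNull())
        return Result { modern.toString(), /*isLegacy*/ false };

    // 3. Neither an override nor a default exists.
    return std::nullopt;
}

int main()
{
    QSettings settings("example-settings.ini", QSettings::IniFormat);
    settings.setValue("model0/promptTemplate", "%1");  // simulated legacy override
    auto r = lookupUpgradeable(settings, "model0/promptTemplate",
                               "model0/chatTemplate", QVariant());
    // Expected: the legacy override is found and flagged as legacy.
    qDebug() << "legacy =" << (r && r->isLegacy) << "value =" << (r ? r->value : QString());
    return 0;
}
```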
void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force) void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force)
{ {
setModelSetting("filename", info, value, force, true); setModelSetting("filename", info, value, force, true);
@ -402,14 +506,77 @@ void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &info, int value, b
setModelSetting("repeatPenaltyTokens", info, value, force, true); setModelSetting("repeatPenaltyTokens", info, value, force, true);
} }
void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force) bool MySettings::setUpgradeableModelSetting(
{ const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
setModelSetting("promptTemplate", info, value, force, true); ) {
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
auto legacyModelKey = modelSettingName(info, legacyKey);
auto newModelKey = modelSettingName(info, newKey );
bool changed = false;
if (m_settings.contains(legacyModelKey)) {
m_settings.remove(legacyModelKey);
changed = true;
}
auto oldValue = m_settings.value(newModelKey);
if (!oldValue.isValid() || oldValue.toString() != value) {
m_settings.setValue(newModelKey, value);
changed = true;
}
return changed;
} }
void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force) bool MySettings::resetUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) {
if (info.id().isEmpty()) {
qWarning("%s: got null model", Q_FUNC_INFO);
return false;
}
auto legacyModelKey = modelSettingName(info, legacyKey);
auto newModelKey = modelSettingName(info, newKey );
bool changed = false;
if (m_settings.contains(legacyModelKey)) {
m_settings.remove(legacyModelKey);
changed = true;
}
if (m_settings.contains(newModelKey)) {
m_settings.remove(newModelKey);
changed = true;
}
return changed;
}
void MySettings::setModelChatTemplate(const ModelInfo &info, const QString &value)
{ {
setModelSetting("systemPrompt", info, value, force, true); using namespace ModelSettingsKey;
if (setUpgradeableModelSetting(info, value, PromptTemplate, ChatTemplate))
emit chatTemplateChanged(info);
}
void MySettings::resetModelChatTemplate(const ModelInfo &info)
{
using namespace ModelSettingsKey;
if (resetUpgradeableModelSetting(info, PromptTemplate, ChatTemplate))
emit chatTemplateChanged(info);
}
void MySettings::setModelSystemMessage(const ModelInfo &info, const QString &value)
{
using namespace ModelSettingsKey;
if (setUpgradeableModelSetting(info, value, SystemPrompt, SystemMessage))
emit systemMessageChanged(info);
}
void MySettings::resetModelSystemMessage(const ModelInfo &info)
{
using namespace ModelSettingsKey;
if (resetUpgradeableModelSetting(info, SystemPrompt, SystemMessage))
emit systemMessageChanged(info);
} }
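For callers, the migration is meant to be invisible: writing through the new setters removes any legacy promptTemplate/systemPrompt override and stores under the new key, while the reset methods clear both keys. A hypothetical call site using only methods declared in mysettings.h; the ChatML-style Jinja template and the `messages`/`message.role`/`message.content` variable names are illustrative assumptions, not taken from this commit:

```cpp
#include <QDebug>
#include <QString>

#include "modellist.h"   // ModelInfo
#include "mysettings.h"

// Hypothetical helper: pin a model to a ChatML-style template, or fall back
// to whatever default the model provides.
void applyTemplateOverride(const ModelInfo &info, bool useOverride)
{
    auto *settings = MySettings::globalInstance();
    if (useOverride) {
        // Illustrative Jinja template; variable names are an assumption.
        settings->setModelChatTemplate(info,
            QStringLiteral("{%- for message in messages %}"
                           "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
                           "{%- endfor %}"));
    } else {
        settings->resetModelChatTemplate(info);  // also clears a legacy promptTemplate override
    }
    qDebug() << info.name() << "template override set:"
             << settings->isModelChatTemplateSet(info);
}
```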
void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force) void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force)
@ -445,7 +612,6 @@ void MySettings::setThreadCount(int value)
emit threadCountChanged(); emit threadCountChanged();
} }
bool MySettings::saveChatsContext() const { return getBasicSetting("saveChatsContext" ).toBool(); }
bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); } bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); }
bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); } bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); }
int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); } int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); }
@ -464,7 +630,6 @@ ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnu
FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); } FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); }
SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); } SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); }
void MySettings::setSaveChatsContext(bool value) { setBasicSetting("saveChatsContext", value); }
void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); } void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); }
void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); } void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); }
void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); } void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); }

View File

@ -4,6 +4,9 @@
#include "modellist.h" // IWYU pragma: keep #include "modellist.h" // IWYU pragma: keep
#include <QDateTime> #include <QDateTime>
#include <QLatin1StringView>
#include <QList>
#include <QModelIndex>
#include <QObject> #include <QObject>
#include <QSettings> #include <QSettings>
#include <QString> #include <QString>
@ -48,7 +51,6 @@ class MySettings : public QObject
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool saveChatsContext READ saveChatsContext WRITE setSaveChatsContext NOTIFY saveChatsContextChanged)
Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged) Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged)
Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged) Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged)
Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged) Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged)
@ -75,9 +77,18 @@ class MySettings : public QObject
Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged) Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT) Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
private:
explicit MySettings();
~MySettings() override = default;
public Q_SLOTS:
void onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles = {});
public: public:
static MySettings *globalInstance(); static MySettings *globalInstance();
Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
// Restore methods // Restore methods
Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info); Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info);
Q_INVOKABLE void restoreApplicationDefaults(); Q_INVOKABLE void restoreApplicationDefaults();
@ -125,10 +136,14 @@ public:
Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false); Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false);
int modelRepeatPenaltyTokens(const ModelInfo &info) const; int modelRepeatPenaltyTokens(const ModelInfo &info) const;
Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false); Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
QString modelPromptTemplate(const ModelInfo &info) const; auto modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting;
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false); Q_INVOKABLE bool isModelChatTemplateSet(const ModelInfo &info) const;
QString modelSystemPrompt(const ModelInfo &info) const; Q_INVOKABLE void setModelChatTemplate(const ModelInfo &info, const QString &value);
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false); Q_INVOKABLE void resetModelChatTemplate(const ModelInfo &info);
auto modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting;
Q_INVOKABLE bool isModelSystemMessageSet(const ModelInfo &info) const;
Q_INVOKABLE void setModelSystemMessage(const ModelInfo &info, const QString &value);
Q_INVOKABLE void resetModelSystemMessage(const ModelInfo &info);
int modelContextLength(const ModelInfo &info) const; int modelContextLength(const ModelInfo &info) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false); Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false);
int modelGpuLayers(const ModelInfo &info) const; int modelGpuLayers(const ModelInfo &info) const;
@ -141,8 +156,6 @@ public:
// Application settings // Application settings
int threadCount() const; int threadCount() const;
void setThreadCount(int value); void setThreadCount(int value);
bool saveChatsContext() const;
void setSaveChatsContext(bool value);
bool systemTray() const; bool systemTray() const;
void setSystemTray(bool value); void setSystemTray(bool value);
bool serverChat() const; bool serverChat() const;
@ -215,12 +228,11 @@ Q_SIGNALS:
void gpuLayersChanged(const ModelInfo &info); void gpuLayersChanged(const ModelInfo &info);
void repeatPenaltyChanged(const ModelInfo &info); void repeatPenaltyChanged(const ModelInfo &info);
void repeatPenaltyTokensChanged(const ModelInfo &info); void repeatPenaltyTokensChanged(const ModelInfo &info);
void promptTemplateChanged(const ModelInfo &info); void chatTemplateChanged(const ModelInfo &info, bool fromInfo = false);
void systemPromptChanged(const ModelInfo &info); void systemMessageChanged(const ModelInfo &info, bool fromInfo = false);
void chatNamePromptChanged(const ModelInfo &info); void chatNamePromptChanged(const ModelInfo &info);
void suggestedFollowUpPromptChanged(const ModelInfo &info); void suggestedFollowUpPromptChanged(const ModelInfo &info);
void threadCountChanged(); void threadCountChanged();
void saveChatsContextChanged();
void systemTrayChanged(); void systemTrayChanged();
void serverChatChanged(); void serverChatChanged();
void modelPathChanged(); void modelPathChanged();
@ -245,6 +257,30 @@ Q_SIGNALS:
void suggestionModeChanged(); void suggestionModeChanged();
void languageAndLocaleChanged(); void languageAndLocaleChanged();
private:
QVariant getBasicSetting(const QString &name) const;
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
QVariant getModelSetting(QLatin1StringView name, const ModelInfo &info) const;
QVariant getModelSetting(const char *name, const ModelInfo &info) const;
void setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
void setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
auto getUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const -> UpgradeableSetting;
bool isUpgradeableModelSettingSet(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
) const;
bool setUpgradeableModelSetting(
const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
);
bool resetUpgradeableModelSetting(
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
);
QString filePathForLocale(const QLocale &locale);
private: private:
QSettings m_settings; QSettings m_settings;
bool m_forceMetal; bool m_forceMetal;
@ -253,18 +289,7 @@ private:
const QStringList m_uiLanguages; const QStringList m_uiLanguages;
std::unique_ptr<QTranslator> m_translator; std::unique_ptr<QTranslator> m_translator;
private:
explicit MySettings();
~MySettings() {}
friend class MyPrivateSettings; friend class MyPrivateSettings;
QVariant getBasicSetting(const QString &name) const;
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
QVariant getModelSetting(const QString &name, const ModelInfo &info) const;
void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force,
bool signal = false);
QString filePathForLocale(const QLocale &locale);
}; };
#endif // MYSETTINGS_H #endif // MYSETTINGS_H
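The new Q_INVOKABLE checkJinjaTemplateError() lets the settings UI validate a user-edited template before persisting it. A hypothetical call site, assuming the returned QVariant is null on success and carries an error string otherwise:

```cpp
#include <QDebug>
#include <QString>
#include <QVariant>

#include "modellist.h"   // ModelInfo
#include "mysettings.h"

// Hypothetical helper: only save the edited template if it parses.
void saveTemplateIfValid(const ModelInfo &info, const QString &edited)
{
    QVariant err = MySettings::checkJinjaTemplateError(edited);
    if (err.isNull())
        MySettings::globalInstance()->setModelChatTemplate(info, edited);
    else
        qWarning() << "Rejecting chat template:" << err.toString();
}
```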

View File

@ -8,6 +8,7 @@
#include "localdocsmodel.h" #include "localdocsmodel.h"
#include "modellist.h" #include "modellist.h"
#include "mysettings.h" #include "mysettings.h"
#include "utils.h"
#include <gpt4all-backend/llmodel.h> #include <gpt4all-backend/llmodel.h>
@ -192,11 +193,14 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
return false; return false;
} }
auto *currentChat = ChatListModel::globalInstance()->currentChat();
Q_ASSERT(currentChat);
auto modelInfo = currentChat->modelInfo();
Q_ASSERT(doc.isObject()); Q_ASSERT(doc.isObject());
Q_ASSERT(ChatListModel::globalInstance()->currentChat());
QJsonObject object = doc.object(); QJsonObject object = doc.object();
object.insert("source", "gpt4all-chat"); object.insert("source", "gpt4all-chat");
object.insert("agent_id", ChatListModel::globalInstance()->currentChat()->modelInfo().filename()); object.insert("agent_id", modelInfo.filename());
object.insert("submitter_id", m_uniqueId); object.insert("submitter_id", m_uniqueId);
object.insert("ingest_id", ingestId); object.insert("ingest_id", ingestId);
@ -204,8 +208,9 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
if (!attribution.isEmpty()) if (!attribution.isEmpty())
object.insert("network/attribution", attribution); object.insert("network/attribution", attribution);
QString promptTemplate = ChatListModel::globalInstance()->currentChat()->modelInfo().promptTemplate(); if (!modelInfo.id().isNull())
object.insert("prompt_template", promptTemplate); if (auto tmpl = modelInfo.chatTemplate().asModern())
object.insert("chat_template"_L1, *tmpl);
QJsonDocument newDoc; QJsonDocument newDoc;
newDoc.setObject(object); newDoc.setObject(object);
@ -358,7 +363,8 @@ void Network::sendStartup()
void Network::trackChatEvent(const QString &ev, QVariantMap props) void Network::trackChatEvent(const QString &ev, QVariantMap props)
{ {
const auto &curChat = ChatListModel::globalInstance()->currentChat(); auto *curChat = ChatListModel::globalInstance()->currentChat();
Q_ASSERT(curChat);
if (!props.contains("model")) if (!props.contains("model"))
props.insert("model", curChat->modelInfo().filename()); props.insert("model", curChat->modelInfo().filename());
props.insert("device_backend", curChat->deviceBackend()); props.insert("device_backend", curChat->deviceBackend());
@ -366,7 +372,7 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
props.insert("doc_collections_enabled", curChat->collectionList().count()); props.insert("doc_collections_enabled", curChat->collectionList().count());
props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount()); props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
props.insert("datalake_active", MySettings::globalInstance()->networkIsActive()); props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
props.insert("using_server", ChatListModel::globalInstance()->currentChat()->isServer()); props.insert("using_server", curChat->isServer());
trackEvent(ev, props); trackEvent(ev, props);
} }

View File

@ -313,11 +313,8 @@ const std::unordered_map<BaseCompletionRequest::Type, const char *> BaseCompleti
class ChatRequest : public BaseCompletionRequest { class ChatRequest : public BaseCompletionRequest {
public: public:
struct Message { struct Message {
enum class Role : uint8_t { enum class Role { System, User, Assistant };
User, Role role;
Assistant,
};
Role role;
QString content; QString content;
}; };
@ -349,7 +346,6 @@ protected:
this->messages.clear(); this->messages.clear();
{ {
QCborArray arr = value.toArray(); QCborArray arr = value.toArray();
Message::Role nextRole = Message::Role::User;
for (qsizetype i = 0; i < arr.size(); i++) { for (qsizetype i = 0; i < arr.size(); i++) {
const auto &elem = arr[i]; const auto &elem = arr[i];
if (!elem.isMap()) if (!elem.isMap())
@ -360,9 +356,9 @@ protected:
QCborMap msg = elem.toMap(); QCborMap msg = elem.toMap();
Message res; Message res;
QString role = takeValue(msg, "role", String, /*required*/ true).toString(); QString role = takeValue(msg, "role", String, /*required*/ true).toString();
if (role == u"system"_s) if (role == u"system"_s) {
continue; // FIXME(jared): don't ignore these res.role = Message::Role::System;
if (role == u"user"_s) { } else if (role == u"user"_s) {
res.role = Message::Role::User; res.role = Message::Role::User;
} else if (role == u"assistant"_s) { } else if (role == u"assistant"_s) {
res.role = Message::Role::Assistant; res.role = Message::Role::Assistant;
@ -374,13 +370,7 @@ protected:
)); ));
} }
res.content = takeValue(msg, "content", String, /*required*/ true).toString(); res.content = takeValue(msg, "content", String, /*required*/ true).toString();
if (res.role != nextRole)
throw InvalidRequestError(fmt::format(
"Invalid 'messages[{}].role': did not expect '{}' here", i, role
));
this->messages.append(res); this->messages.append(res);
nextRole = res.role == Message::Role::User ? Message::Role::Assistant
: Message::Role::User;
if (!msg.isEmpty()) if (!msg.isEmpty())
throw InvalidRequestError(fmt::format( throw InvalidRequestError(fmt::format(
@ -630,8 +620,7 @@ void Server::start()
}); });
#endif #endif
connect(this, &Server::requestServerNewPromptResponsePair, m_chat, connect(this, &Server::requestResetResponseState, m_chat, &Chat::resetResponseState, Qt::BlockingQueuedConnection);
&Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection);
} }
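The requestResetResponseState signal keeps the same blocking hand-off as the old requestServerNewPromptResponsePair: the HTTP worker emits across threads and waits until the chat object's thread has finished the slot. A generic sketch of that Qt::BlockingQueuedConnection pattern, with all names illustrative rather than the gpt4all classes:

```cpp
#include <QCoreApplication>
#include <QDebug>
#include <QObject>
#include <QThread>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QObject mainThreadTarget;          // lives in the main (GUI) thread
    QThread worker;
    QObject workerContext;
    workerContext.moveToThread(&worker);

    QObject::connect(&worker, &QThread::started, &workerContext, [&] {
        // Runs in the worker thread.
        QMetaObject::invokeMethod(&mainThreadTarget, [] {
            qDebug() << "main thread: response state reset";
        }, Qt::BlockingQueuedConnection);          // blocks until the lambda has run
        qDebug() << "worker thread: continuing after reset";
        QMetaObject::invokeMethod(&app, &QCoreApplication::quit);
    });

    worker.start();
    int rc = app.exec();
    worker.quit();
    worker.wait();
    return rc;
}
```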
static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>> static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
@ -642,6 +631,10 @@ static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::opt
auto Server::handleCompletionRequest(const CompletionRequest &request) auto Server::handleCompletionRequest(const CompletionRequest &request)
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>> -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{ {
Q_ASSERT(m_chatModel);
auto *mySettings = MySettings::globalInstance();
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList(); const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
for (const ModelInfo &info : modelList) { for (const ModelInfo &info : modelList) {
@ -654,10 +647,6 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
} }
} }
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(request.prompt); // blocks
resetResponse();
// load the new model if necessary // load the new model if necessary
setShouldBeLoaded(true); setShouldBeLoaded(true);
@ -666,47 +655,55 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
emit requestResetResponseState(); // blocks
qsizetype prevMsgIndex = m_chatModel->count() - 1;
if (prevMsgIndex >= 0)
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
// NB: this resets the context, regardless of whether this model is already loaded // NB: this resets the context, regardless of whether this model is already loaded
if (!loadModel(modelInfo)) { if (!loadModel(modelInfo)) {
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl; std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results // add prompt/response items to GUI
const int top_k = modelInfo.topK(); m_chatModel->appendPrompt(request.prompt);
const int n_batch = modelInfo.promptBatchSize(); m_chatModel->appendResponse(prevMsgIndex + 1);
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
.min_p = request.min_p,
.temp = request.temperature,
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
};
auto promptUtf8 = request.prompt.toUtf8();
int promptTokens = 0; int promptTokens = 0;
int responseTokens = 0; int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses; QStringList responses;
for (int i = 0; i < request.n; ++i) { for (int i = 0; i < request.n; ++i) {
if (!promptInternal( PromptResult result;
m_collections, try {
request.prompt, result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
/*promptTemplate*/ u"%1"_s, promptCtx,
request.max_tokens, /*usedLocalDocs*/ false);
top_k, } catch (const std::exception &e) {
request.top_p, emit responseChanged(e.what());
request.min_p, emit responseStopped(0);
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n)) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
QString resp = response(/*trim*/ false); QString resp = QString::fromUtf8(result.response);
if (request.echo) if (request.echo)
resp = request.prompt + resp; resp = request.prompt + resp;
responses.append({resp, m_databaseResults}); responses << resp;
if (!promptTokens) if (i == 0)
promptTokens = m_promptTokens; promptTokens = result.promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens; responseTokens += result.responseTokens;
if (i < request.n - 1)
resetResponse();
} }
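promptInternal() now receives a std::string_view over the request's UTF-8 bytes rather than a QString, so the backing QByteArray must stay alive for as long as the view is used. A minimal illustration of that conversion (C++20 iterator-pair constructor, names illustrative):

```cpp
#include <QByteArray>
#include <QString>
#include <iostream>
#include <string_view>

int main()
{
    QString prompt = QStringLiteral("The quick brown fox");
    // The view does not own the bytes; keep promptUtf8 in scope while it is used.
    QByteArray promptUtf8 = prompt.toUtf8();
    std::string_view view(promptUtf8.cbegin(), promptUtf8.cend());
    std::cout << view << " (" << view.size() << " bytes)\n";
    return 0;
}
```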
QJsonObject responseObject { QJsonObject responseObject {
@ -717,25 +714,13 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
}; };
QJsonArray choices; QJsonArray choices;
{ for (qsizetype i = 0; auto &resp : std::as_const(responses)) {
int index = 0; choices << QJsonObject {
for (const auto &r : responses) { { "text", resp },
QString result = r.first; { "index", i++ },
QList<ResultInfo> infos = r.second; { "logprobs", QJsonValue::Null },
QJsonObject choice { { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
{ "text", result }, };
{ "index", index++ },
{ "logprobs", QJsonValue::Null },
{ "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
};
if (MySettings::globalInstance()->localDocsShowReferences()) {
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
}
choices.append(choice);
}
} }
responseObject.insert("choices", choices); responseObject.insert("choices", choices);
@ -751,6 +736,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
auto Server::handleChatRequest(const ChatRequest &request) auto Server::handleChatRequest(const ChatRequest &request)
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>> -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{ {
auto *mySettings = MySettings::globalInstance();
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList(); const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
for (const ModelInfo &info : modelList) { for (const ModelInfo &info : modelList) {
@ -771,83 +758,58 @@ auto Server::handleChatRequest(const ChatRequest &request)
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
emit requestResetResponseState(); // blocks
// NB: this resets the context, regardless of whether this model is already loaded // NB: this resets the context, regardless of whether this model is already loaded
if (!loadModel(modelInfo)) { if (!loadModel(modelInfo)) {
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl; std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
const QString promptTemplate = modelInfo.promptTemplate(); m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
const int top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const auto repeat_penalty = float(modelInfo.repeatPenalty());
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
int promptTokens = 0; Q_ASSERT(!request.messages.isEmpty());
// adds prompt/response items to GUI
std::vector<ChatItem> chatItems;
for (auto &message : request.messages) {
using enum ChatRequest::Message::Role;
switch (message.role) {
case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break;
case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break;
case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break;
}
}
m_chatModel->appendResponseWithHistory(chatItems);
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
.min_p = request.min_p,
.temp = request.temperature,
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
};
int promptTokens = 0;
int responseTokens = 0; int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses; QList<QPair<QString, QList<ResultInfo>>> responses;
Q_ASSERT(!request.messages.isEmpty());
Q_ASSERT(request.messages.size() % 2 == 1);
for (int i = 0; i < request.messages.size() - 2; i += 2) {
using enum ChatRequest::Message::Role;
auto &user = request.messages[i];
auto &assistant = request.messages[i + 1];
Q_ASSERT(user.role == User);
Q_ASSERT(assistant.role == Assistant);
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(user.content); // blocks
resetResponse();
if (!promptInternal(
{},
user.content,
promptTemplate,
request.max_tokens,
top_k,
request.top_p,
request.min_p,
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n,
assistant.content)
) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
}
promptTokens += m_promptResponseTokens; // previous responses are part of current prompt
}
QString lastMessage = request.messages.last().content;
// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(lastMessage); // blocks
resetResponse();
for (int i = 0; i < request.n; ++i) { for (int i = 0; i < request.n; ++i) {
if (!promptInternal( ChatPromptResult result;
m_collections, try {
lastMessage, result = promptInternalChat(m_collections, promptCtx);
promptTemplate, } catch (const std::exception &e) {
request.max_tokens, emit responseChanged(e.what());
top_k, emit responseStopped(0);
request.top_p,
request.min_p,
request.temperature,
n_batch,
repeat_penalty,
repeat_last_n)
) {
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return makeError(QHttpServerResponder::StatusCode::InternalServerError); return makeError(QHttpServerResponder::StatusCode::InternalServerError);
} }
responses.append({response(), m_databaseResults}); responses.emplace_back(result.response, result.databaseResults);
// FIXME(jared): these are UI counts and do not include framing tokens, which they should
if (i == 0) if (i == 0)
promptTokens += m_promptTokens; promptTokens = result.promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens; responseTokens += result.responseTokens;
if (i != request.n - 1)
resetResponse();
} }
QJsonObject responseObject { QJsonObject responseObject {

View File

@ -33,7 +33,7 @@ public Q_SLOTS:
void start(); void start();
Q_SIGNALS: Q_SIGNALS:
void requestServerNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {}); void requestResetResponseState();
private: private:
auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>; auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;

View File

@ -3,23 +3,41 @@
#include <fmt/base.h> #include <fmt/base.h>
#include <fmt/format.h> #include <fmt/format.h>
#include <QByteArray>
#include <QJsonValue>
#include <QLatin1StringView>
#include <QString> #include <QString>
#include <QStringView>
#include <QUtf8StringView>
#include <QVariant> #include <QVariant>
#include <string> #include <initializer_list>
#include <string_view>
#include <utility>
class QJsonObject;
// fmtlib formatters for QString and QVariant // fmtlib formatters for QString and QVariant
#define MAKE_FORMATTER(type, conversion) \ #define MAKE_FORMATTER(type, conversion) \
template <> \ template <> \
struct fmt::formatter<type, char>: fmt::formatter<std::string, char> { \ struct fmt::formatter<type, char>: fmt::formatter<std::string_view, char> { \
template <typename FmtContext> \ template <typename FmtContext> \
FmtContext::iterator format(const type &value, FmtContext &ctx) const \ FmtContext::iterator format(const type &value, FmtContext &ctx) const \
{ \ { \
return formatter<std::string, char>::format(conversion, ctx); \ auto valueUtf8 = (conversion); \
} \ std::string_view view(valueUtf8.cbegin(), valueUtf8.cend()); \
return formatter<std::string_view, char>::format(view, ctx); \
} \
} }
MAKE_FORMATTER(QString, value.toStdString() ); MAKE_FORMATTER(QUtf8StringView, value );
MAKE_FORMATTER(QVariant, value.toString().toStdString()); MAKE_FORMATTER(QStringView, value.toUtf8() );
MAKE_FORMATTER(QString, value.toUtf8() );
MAKE_FORMATTER(QVariant, value.toString().toUtf8());
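With the formatters routed through UTF-8 views, fmt can consume Qt string types directly and non-ASCII text is passed through byte-for-byte instead of going via a locale-dependent std::string conversion. A small usage sketch, assuming fmt and this header are on the include path (values are illustrative):

```cpp
#include <fmt/format.h>

#include <QString>
#include <QVariant>

#include "utils.h"  // pulls in the MAKE_FORMATTER specializations above

int main()
{
    QString who = QStringLiteral("Jinja2Cpp");
    QVariant version = QStringLiteral("1.3.2");
    // Both arguments are converted via toUtf8()/QUtf8StringView before formatting.
    fmt::print("dependency: {} ({})\n", who, version);
    return 0;
}
```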
// alternative to QJsonObject's initializer_list constructor, with keys given as Latin-1 string views // alternative to QJsonObject's initializer_list constructor, with keys given as Latin-1 string views
QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args);
#include "utils.inl"

View File

@ -0,0 +1,9 @@
#include <QJsonObject>
inline QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args)
{
QJsonObject obj;
for (auto &arg : args)
obj.insert(arg.first, arg.second);
return obj;
}
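A hedged usage sketch of makeJsonObject(): keys are cheap QLatin1StringView literals rather than allocated QStrings, while values are ordinary QJsonValues. The field names mirror the completions response built in server.cpp but are only illustrative here:

```cpp
#include <QDebug>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>

#include "utils.h"

using namespace Qt::Literals::StringLiterals;

int main()
{
    // No QString allocation is needed just to name the fields.
    QJsonObject choice = makeJsonObject({
        { "text"_L1,          u"Hello!"_s },
        { "index"_L1,         0           },
        { "finish_reason"_L1, "stop"_L1   },
    });
    qDebug().noquote() << QJsonDocument(choice).toJson(QJsonDocument::Compact);
    return 0;
}
```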

View File

@ -203,11 +203,10 @@ EXPECTED_MODEL_INFO = {
EXPECTED_COMPLETIONS_RESPONSE = { EXPECTED_COMPLETIONS_RESPONSE = {
'choices': [ 'choices': [
{ {
'finish_reason': 'stop', 'finish_reason': 'length',
'index': 0, 'index': 0,
'logprobs': None, 'logprobs': None,
'references': None, 'text': ' jumps over the lazy dog.\n',
'text': ' jumps over the lazy dog.',
}, },
], ],
'id': 'placeholder', 'id': 'placeholder',
@ -242,18 +241,14 @@ def test_with_models(chat_server_with_model: None) -> None:
'type': 'invalid_request_error', 'type': 'invalid_request_error',
}} }}
data = { data = dict(
'model': 'Llama 3.2 1B Instruct', model = 'Llama 3.2 1B Instruct',
'prompt': 'The quick brown fox', prompt = 'The quick brown fox',
'temperature': 0, temperature = 0,
} max_tokens = 6,
)
response = request.post('completions', data=data) response = request.post('completions', data=data)
assert len(response['choices']) == 1 del response['created'] # Remove the dynamic field for comparison
assert response['choices'][0].keys() == {'text', 'index', 'logprobs', 'references', 'finish_reason'}
assert response['choices'][0]['text'] == ' jumps over the lazy dog.'
assert 'created' in response
response.pop('created') # Remove the dynamic field for comparison
assert response == EXPECTED_COMPLETIONS_RESPONSE assert response == EXPECTED_COMPLETIONS_RESPONSE