Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-23 05:58:48 +00:00)
commit 225bf6be93 (parent 3320094d29)

Remove binary state from high-level API and use Jinja templates (#3147)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
.gitmodules (vendored): 3 additions

@@ -17,3 +17,6 @@
 [submodule "gpt4all-chat/deps/QXlsx"]
     path = gpt4all-chat/deps/QXlsx
     url = https://github.com/nomic-ai/QXlsx.git
+[submodule "gpt4all-chat/deps/Jinja2Cpp"]
+    path = gpt4all-chat/deps/Jinja2Cpp
+    url = https://github.com/nomic-ai/jinja2cpp.git
@@ -5,6 +5,7 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <expected>
 #include <functional>
 #include <optional>
 #include <span>
@@ -24,6 +25,10 @@ using namespace std::string_literals;
 class LLModel {
 public:
     using Token = int32_t;
+    using PromptCallback = std::function<bool(std::span<const Token> batch, bool cached)>;
+    using ResponseCallback = std::function<bool(Token token, std::string_view piece)>;
+    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+    using ProgressCallback = std::function<bool(float progress)>;
+
     class BadArchError: public std::runtime_error {
     public:
@@ -101,6 +106,7 @@ public:
         static int32_t maxContextLength(const std::string &modelPath);
         static int32_t layerCount(const std::string &modelPath);
         static bool isEmbeddingModel(const std::string &modelPath);
+        static auto chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>;
         static void setImplementationsSearchPath(const std::string &path);
         static const std::string &implementationsSearchPath();
         static bool hasSupportedCPU();
@@ -124,7 +130,6 @@ public:
     };

     struct PromptContext {
-        int32_t n_past = 0; // number of tokens in past conversation
         int32_t n_predict = 200;
         int32_t top_k = 40;
         float top_p = 0.9f;
@@ -136,8 +141,6 @@ public:
         float contextErase = 0.5f; // percent of context to erase if we exceed the context window
     };

-    using ProgressCallback = std::function<bool(float progress)>;
-
     explicit LLModel() {}
     virtual ~LLModel() {}

@@ -154,16 +157,12 @@ public:

     // This method requires the model to return true from supportsCompletion otherwise it will throw
     // an error
-    virtual void prompt(const std::string &prompt,
-                        const std::string &promptTemplate,
-                        std::function<bool(int32_t)> promptCallback,
-                        std::function<bool(int32_t, const std::string&)> responseCallback,
-                        bool allowContextShift,
-                        PromptContext &ctx,
-                        bool special = false,
-                        std::optional<std::string_view> fakeReply = {});
+    virtual void prompt(std::string_view prompt,
+                        const PromptCallback &promptCallback,
+                        const ResponseCallback &responseCallback,
+                        const PromptContext &ctx);

-    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+    virtual int32_t countPromptTokens(std::string_view prompt) const;

     virtual size_t embeddingSize() const {
         throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
@@ -209,23 +208,22 @@ public:
     void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

     virtual int32_t contextLength() const = 0;
+    virtual auto specialTokens() -> std::unordered_map<std::string, std::string> const = 0;

 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
+    virtual std::vector<Token> tokenize(std::string_view str) const = 0;
     virtual bool isSpecialToken(Token id) const = 0;
     virtual std::string tokenToString(Token id) const = 0;
-    virtual void initSampler(PromptContext &ctx) = 0;
+    virtual void initSampler(const PromptContext &ctx) = 0;
     virtual Token sampleToken() const = 0;
-    virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
-    virtual void shiftContext(PromptContext &promptCtx) = 0;
+    virtual bool evalTokens(int32_t nPast, std::span<const Token> tokens) const = 0;
+    virtual void shiftContext(const PromptContext &promptCtx, int32_t *nPast) = 0;
     virtual int32_t inputLength() const = 0;
-    virtual void setTokenizeInputPosition(int32_t pos) = 0;
-    virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-        -> std::vector<Token>::const_iterator = 0;
-    virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
-    virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
+    virtual int32_t computeModelInputPosition(std::span<const Token> input) const = 0;
+    virtual void setModelInputPosition(int32_t pos) = 0;
+    virtual void appendInputToken(Token tok) = 0;
     virtual std::span<const Token> inputTokens() const = 0;
     virtual const std::vector<Token> &endTokens() const = 0;
     virtual bool shouldAddBOS() const = 0;
@@ -242,6 +240,12 @@ protected:
         return -1;
     }

+    virtual auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
+    {
+        (void)modelPath;
+        return std::unexpected("not implemented");
+    }
+
     const Implementation *m_implementation = nullptr;

     ProgressCallback m_progressCallback;
@@ -253,19 +257,15 @@ protected:
         return true;
     }

-    bool decodePrompt(std::function<bool(int32_t)> promptCallback,
-                      std::function<bool(int32_t, const std::string&)> responseCallback,
-                      bool allowContextShift,
-                      PromptContext &promptCtx,
-                      std::vector<Token> embd_inp,
-                      bool isResponse = false,
-                      bool alwaysDecode = false);
-    void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
-                          bool allowContextShift,
-                          PromptContext &promptCtx);
+    // prefill context with prompt
+    auto decodePrompt(const PromptCallback &promptCallback,
+                      const PromptContext &promptCtx,
+                      std::vector<Token> embd_inp)
+        -> std::optional<int32_t>;
+    // generate a response
+    void generateResponse(const ResponseCallback &responseCallback,
+                          const PromptContext &promptCtx,
+                          int32_t nPast);

-protected:
-    Token m_tokenize_last_token = -1; // not serialized
-
     friend class LLMImplementation;
 };
@@ -35,16 +35,15 @@ typedef int32_t token_t;
  * behavior.
  */
 struct llmodel_prompt_context {
-    int32_t n_past;        // number of tokens in past conversation
     int32_t n_predict;     // number of tokens to predict
     int32_t top_k;         // top k logits to sample from
     float top_p;           // nucleus sampling probability threshold
     float min_p;           // Min P sampling
     float temp;            // temperature to adjust model's output distribution
     int32_t n_batch;       // number of predictions to generate in parallel
     float repeat_penalty;  // penalty factor for repeated tokens
     int32_t repeat_last_n; // last n tokens to penalize
     float context_erase;   // percent of context to erase if we exceed the context window
 };

 struct llmodel_gpu_device {
@@ -63,10 +62,12 @@ typedef struct llmodel_gpu_device llmodel_gpu_device;

 /**
  * Callback type for prompt processing.
- * @param token_id The token id of the prompt.
+ * @param token_ids An array of token ids of the prompt.
+ * @param n_token_ids The number of tokens in the array.
+ * @param cached Whether the tokens were already in cache.
  * @return a bool indicating whether the model should keep processing.
  */
-typedef bool (*llmodel_prompt_callback)(int32_t token_id);
+typedef bool (*llmodel_prompt_callback)(const token_t *token_ids, size_t n_token_ids, bool cached);

 /**
  * Callback type for response.
@@ -74,7 +75,7 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
  * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
  * @return a bool indicating whether the model should keep generating.
  */
-typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
+typedef bool (*llmodel_response_callback)(token_t token_id, const char *response);

 /**
  * Embedding cancellation callback for use with llmodel_embed.
@@ -85,6 +86,8 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
  */
 typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);

+typedef void (*llmodel_special_token_callback)(const char *name, const char *token);
+
 /**
  * Create a llmodel instance.
  * Recognises correct model type from file at model_path
@@ -183,22 +186,17 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
  * @param prompt A string representing the input prompt.
- * @param prompt_template A string representing the input prompt template.
  * @param prompt_callback A callback function for handling the processing of prompt.
  * @param response_callback A callback function for handling the generated response.
- * @param allow_context_shift Whether to allow shifting of context to make room for more input.
- * @param special True if special tokens in the prompt should be processed, false otherwise.
- * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
  * @param ctx A pointer to the llmodel_prompt_context structure.
+ * @param error A pointer to a string; will only be set on error.
  */
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+bool llmodel_prompt(llmodel_model model,
+                    const char *prompt,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
-                    bool allow_context_shift,
-                    llmodel_prompt_context *ctx,
-                    bool special,
-                    const char *fake_reply);
+                    llmodel_prompt_context *ctx,
+                    const char **error);

 /**
  * Generate an embedding using the model.
@@ -310,6 +308,10 @@ const char *llmodel_model_backend_name(llmodel_model model);
  */
 const char *llmodel_model_gpu_device_name(llmodel_model model);

+int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
+
+void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
+
 #ifdef __cplusplus
 }
 #endif
@@ -202,7 +202,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
     if (keyidx != -1) {
         value = gguf_get_val_u32(ctx, keyidx);
     } else {
-        std::cerr << __func__ << ": " << key << "not found in " << modelPath << "\n";
+        std::cerr << __func__ << ": " << key << " not found in " << modelPath << "\n";
     }
 }

@@ -518,18 +518,13 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const
     return bytesRead;
 }

-std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
+std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str) const
 {
-    bool atStart = m_tokenize_last_token == -1;
-    bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
     std::vector<LLModel::Token> fres(str.length() + 4);
-    int32_t fres_len = llama_tokenize_gpt4all(
-        d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
-        /*parse_special*/ special, /*insert_space*/ insertSpace
+    int32_t fres_len = llama_tokenize(
+        d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ true, /*parse_special*/ true
     );
     fres.resize(fres_len);
-    if (fres_len)
-        m_tokenize_last_token = fres.back();
     return fres;
 }

@@ -555,7 +550,7 @@ std::string LLamaModel::tokenToString(Token id) const
     return std::string(result.data(), result.size());
 }

-void LLamaModel::initSampler(PromptContext &promptCtx)
+void LLamaModel::initSampler(const PromptContext &promptCtx)
 {
     auto *model = d_ptr->model;
     auto *chain = d_ptr->sampler_chain;
@@ -601,9 +596,11 @@ LLModel::Token LLamaModel::sampleToken() const
     return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
 }

-bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
+bool LLamaModel::evalTokens(int32_t nPast, std::span<const Token> tokens) const
 {
-    llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
+    assert(!tokens.empty());
+
+    llama_kv_cache_seq_rm(d_ptr->ctx, 0, nPast, -1);

     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);

@@ -611,7 +608,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c

     for (int32_t i = 0; i < batch.n_tokens; i++) {
         batch.token   [i] = tokens[i];
-        batch.pos     [i] = ctx.n_past + i;
+        batch.pos     [i] = nPast + i;
         batch.n_seq_id[i] = 1;
         batch.seq_id  [i][0] = 0;
         batch.logits  [i] = false;
@@ -625,13 +622,13 @@ bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) c
     return res == 0;
 }

-void LLamaModel::shiftContext(PromptContext &promptCtx)
+void LLamaModel::shiftContext(const PromptContext &promptCtx, int32_t *nPast)
 {
     // infinite text generation via context shifting

     // erase up to n_ctx*contextErase tokens
     int n_keep = shouldAddBOS();
-    int n_past = promptCtx.n_past;
+    int n_past = *nPast;
     int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

     assert(n_discard > 0);
@@ -647,7 +644,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)

     auto &inp = d_ptr->inputTokens;
     inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
-    promptCtx.n_past = inp.size();
+    *nPast = inp.size();
 }

 int32_t LLamaModel::contextLength() const
@@ -655,39 +652,37 @@ int32_t LLamaModel::contextLength() const
     return llama_n_ctx(d_ptr->ctx);
 }

+auto LLamaModel::specialTokens() -> std::unordered_map<std::string, std::string> const
+{
+    if (!d_ptr->model)
+        throw std::logic_error("model not loaded");
+
+    std::unordered_map<std::string, std::string> tokens;
+    if (auto id = llama_token_bos(d_ptr->model); id != LLAMA_TOKEN_NULL)
+        tokens.emplace("bos_token", tokenToString(id));
+    if (auto id = llama_token_eos(d_ptr->model); id != LLAMA_TOKEN_NULL)
+        tokens.emplace("eos_token", tokenToString(id));
+    return tokens;
+}
+
 int32_t LLamaModel::inputLength() const
 {
     return d_ptr->inputTokens.size();
 }

-void LLamaModel::setTokenizeInputPosition(int32_t pos)
+int32_t LLamaModel::computeModelInputPosition(std::span<const Token> input) const
 {
-    assert(pos >= 0);
-    m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
-}
-
-auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-    -> std::vector<Token>::const_iterator
-{
-    assert(ctx.n_past >= 0);
-    auto pos = size_t(ctx.n_past);
-    if (pos > d_ptr->inputTokens.size()) {
-        std::ostringstream ss;
-        ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
-        throw std::out_of_range(ss.str());
-    }
-
     // find common prefix
     auto cacheIt = d_ptr->inputTokens.begin();
     auto inputIt = input.begin();
     while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
-        ++cacheIt; ++inputIt; ++pos;
+        ++cacheIt; ++inputIt;
     }
     // tell the caller to ignore the tokens between [begin, inputIt)
-    return inputIt;
+    return inputIt - input.begin();
 }

-void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
+void LLamaModel::setModelInputPosition(int32_t pos)
 {
     auto &inp = d_ptr->inputTokens;
     assert(pos >= 0);
@@ -695,13 +690,11 @@ void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
     // truncate token cache to end at the new n_past
     if (pos < inp.size())
         inp.resize(pos);
-    ctx.n_past = pos;
 }

-void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
+void LLamaModel::appendInputToken(Token tok)
 {
     d_ptr->inputTokens.push_back(tok);
-    ctx.n_past += 1;
 }

 auto LLamaModel::inputTokens() const -> std::span<const Token>
@@ -729,6 +722,37 @@ int32_t LLamaModel::layerCount(std::string const &modelPath) const
     return get_arch_key_u32(modelPath, "block_count");
 }

+// TODO(jared): reduce redundant code and operations by combining all metadata getters for unloaded
+// models into a class that keeps the model file open
+auto LLamaModel::chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string>
+{
+    auto *ctx = load_gguf(modelPath);
+    if (!ctx)
+        return std::unexpected("failed to open model file");
+
+    std::expected<std::string, std::string> result;
+    enum gguf_type ktype;
+    const int kid = gguf_find_key(ctx, "tokenizer.chat_template");
+    if (kid == -1) {
+        result = std::unexpected("key not found");
+        goto cleanup;
+    }
+
+    ktype = gguf_get_kv_type(ctx, kid);
+    if (ktype != GGUF_TYPE_STRING) {
+        result = std::unexpected(
+            "expected key type STRING (" + std::to_string(GGUF_TYPE_STRING) + "), got " + std::to_string(ktype)
+        );
+        goto cleanup;
+    }
+
+    result = gguf_get_val_str(ctx, kid);
+
+cleanup:
+    gguf_free(ctx);
+    return result;
+}
+
 #ifdef GGML_USE_VULKAN
 static const char *getVulkanVendorName(uint32_t vendorID)
 {
@@ -11,6 +11,7 @@
 #include <string>
 #include <string_view>
 #include <vector>
+#include <unordered_map>

 struct LLamaPrivate;
 struct EmbModelSpec;
@@ -49,26 +50,26 @@ public:
                size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

     int32_t contextLength() const override;
+    auto specialTokens() -> std::unordered_map<std::string, std::string> const override;

 protected:
-    std::vector<Token> tokenize(std::string_view str, bool special) override;
+    std::vector<Token> tokenize(std::string_view str) const override;
     bool isSpecialToken(Token id) const override;
     std::string tokenToString(Token id) const override;
-    void initSampler(PromptContext &ctx) override;
+    void initSampler(const PromptContext &ctx) override;
     Token sampleToken() const override;
-    bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
-    void shiftContext(PromptContext &promptCtx) override;
+    bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override;
+    void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override;
     int32_t inputLength() const override;
-    void setTokenizeInputPosition(int32_t pos) override;
-    auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-        -> std::vector<Token>::const_iterator override;
-    void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
-    void appendInputToken(PromptContext &ctx, Token tok) override;
+    int32_t computeModelInputPosition(std::span<const Token> input) const override;
+    void setModelInputPosition(int32_t pos) override;
+    void appendInputToken(Token tok) override;
     std::span<const Token> inputTokens() const override;
     const std::vector<Token> &endTokens() const override;
     bool shouldAddBOS() const override;
     int32_t maxContextLength(std::string const &modelPath) const override;
     int32_t layerCount(std::string const &modelPath) const override;
+    auto chatTemplate(const char *modelPath) const -> std::expected<std::string, std::string> override;

     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
                size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
@@ -326,6 +326,12 @@ bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath)
     return llama && llama->isEmbeddingModel(modelPath);
 }

+auto LLModel::Implementation::chatTemplate(const char *modelPath) -> std::expected<std::string, std::string>
+{
+    auto *llama = constructGlobalLlama();
+    return llama ? llama->chatTemplate(modelPath) : std::unexpected("backend not available");
+}
+
 void LLModel::Implementation::setImplementationsSearchPath(const std::string& path)
 {
     s_implementations_search_path = path;
@@ -7,7 +7,6 @@
 #include <cstdlib>
 #include <cstring>
 #include <exception>
-#include <functional>
 #include <iostream>
 #include <memory>
 #include <optional>
@@ -22,7 +21,6 @@ static_assert(sizeof(token_t) == sizeof(LLModel::Token));

 struct LLModelWrapper {
     LLModel *llModel = nullptr;
-    LLModel::PromptContext promptContext;
     ~LLModelWrapper() { delete llModel; }
 };

@@ -126,49 +124,44 @@ uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint6
     return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
 }

-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+bool llmodel_prompt(llmodel_model model,
+                    const char *prompt,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
-                    bool allow_context_shift,
-                    llmodel_prompt_context *ctx,
-                    bool special,
-                    const char *fake_reply)
+                    llmodel_prompt_context *ctx,
+                    const char **error)
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);

-    auto response_func = [response_callback](int32_t token_id, const std::string &response) {
-        return response_callback(token_id, response.c_str());
+    // Copy the C prompt context
+    LLModel::PromptContext promptContext {
+        .n_predict      = ctx->n_predict,
+        .top_k          = ctx->top_k,
+        .top_p          = ctx->top_p,
+        .min_p          = ctx->min_p,
+        .temp           = ctx->temp,
+        .n_batch        = ctx->n_batch,
+        .repeat_penalty = ctx->repeat_penalty,
+        .repeat_last_n  = ctx->repeat_last_n,
+        .contextErase   = ctx->context_erase,
     };

-    // Copy the C prompt context
-    wrapper->promptContext.n_past = ctx->n_past;
-    wrapper->promptContext.n_predict = ctx->n_predict;
-    wrapper->promptContext.top_k = ctx->top_k;
-    wrapper->promptContext.top_p = ctx->top_p;
-    wrapper->promptContext.min_p = ctx->min_p;
-    wrapper->promptContext.temp = ctx->temp;
-    wrapper->promptContext.n_batch = ctx->n_batch;
-    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
-    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
-    wrapper->promptContext.contextErase = ctx->context_erase;
+    auto prompt_func = [prompt_callback](std::span<const LLModel::Token> token_ids, bool cached) {
+        return prompt_callback(token_ids.data(), token_ids.size(), cached);
+    };
+    auto response_func = [response_callback](LLModel::Token token_id, std::string_view piece) {
+        return response_callback(token_id, piece.data());
+    };

     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-                             wrapper->promptContext, special,
-                             fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
+    try {
+        wrapper->llModel->prompt(prompt, prompt_func, response_func, promptContext);
+    } catch (std::exception const &e) {
+        llmodel_set_error(error, e.what());
+        return false;
+    }

-    // Update the rest of the C prompt context
-    ctx->n_past = wrapper->promptContext.n_past;
-    ctx->n_predict = wrapper->promptContext.n_predict;
-    ctx->top_k = wrapper->promptContext.top_k;
-    ctx->top_p = wrapper->promptContext.top_p;
-    ctx->min_p = wrapper->promptContext.min_p;
-    ctx->temp = wrapper->promptContext.temp;
-    ctx->n_batch = wrapper->promptContext.n_batch;
-    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
-    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
-    ctx->context_erase = wrapper->promptContext.contextErase;
+    return true;
 }

 float *llmodel_embed(
@@ -307,3 +300,21 @@ const char *llmodel_model_gpu_device_name(llmodel_model model)
     const auto *wrapper = static_cast<LLModelWrapper *>(model);
     return wrapper->llModel->gpuDeviceName();
 }
+
+int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error)
+{
+    auto *wrapper = static_cast<const LLModelWrapper *>(model);
+    try {
+        return wrapper->llModel->countPromptTokens(prompt);
+    } catch (const std::exception& e) {
+        llmodel_set_error(error, e.what());
+        return -1;
+    }
+}
+
+void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
+{
+    auto *wrapper = static_cast<const LLModelWrapper *>(model);
+    for (auto &[name, token] : wrapper->llModel->specialTokens())
+        callback(name.c_str(), token.c_str());
+}
@@ -4,232 +4,120 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
-#include <functional>
 #include <iostream>
 #include <iterator>
 #include <optional>
-#include <regex>
-#include <sstream>
+#include <ranges>
 #include <stdexcept>
 #include <string>
 #include <string_view>
 #include <vector>

 namespace ranges = std::ranges;
+namespace views = std::ranges::views;

-static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
-{
-    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
+void LLModel::prompt(
+    std::string_view prompt,
+    const PromptCallback &promptCallback,
+    const ResponseCallback &responseCallback,
+    const PromptContext &promptCtx
+) {
+    if (!isModelLoaded())
+        throw std::invalid_argument("Attempted to prompt an unloaded model.");
+    if (!supportsCompletion())
+        throw std::invalid_argument("Not a text completion model.");
+    if (!promptCtx.n_batch)
+        throw std::invalid_argument("Batch size cannot be zero.");
+    if (!promptCtx.n_predict)
+        return; // nothing requested

-    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
-    placeholders.clear();
-    placeholders.insert(placeholders.end(), it, std::sregex_iterator());
+    auto embd_inp = tokenize(prompt);
+    if (embd_inp.empty())
+        throw std::invalid_argument("Prompt tokenized to zero tokens.");

-    if (placeholders.size() > 2) {
-        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
-        return false;
-    }
-    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
-        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
-        return false;
-    }
-    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
-        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
-        return false;
-    }
-    return true;
+    if (auto res = decodePrompt(promptCallback, promptCtx, std::move(embd_inp)))
+        generateResponse(responseCallback, promptCtx, /*n_past*/ *res);
 }

-void LLModel::prompt(const std::string &prompt,
-                     const std::string &promptTemplate,
-                     std::function<bool(int32_t)> promptCallback,
-                     std::function<bool(int32_t, const std::string&)> responseCallback,
-                     bool allowContextShift,
-                     PromptContext &promptCtx,
-                     bool special,
-                     std::optional<std::string_view> fakeReply)
+int32_t LLModel::countPromptTokens(std::string_view prompt) const
 {
-    if (!isModelLoaded()) {
-        std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
-        return;
-    }
-
-    if (!supportsCompletion()) {
-        std::string errorMessage = "ERROR: this model does not support text completion or chat!";
-        responseCallback(-1, errorMessage);
-        std::cerr << implementation().modelType() << " " << errorMessage << "\n";
-        return;
-    }
-
-    // sanity checks
-    if (promptCtx.n_past > contextLength()) {
-        std::ostringstream ss;
-        ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
-        throw std::out_of_range(ss.str());
-    }
-    if (promptCtx.n_past > inputLength()) {
-        std::ostringstream ss;
-        ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
-        throw std::out_of_range(ss.str());
-    }
-
-    promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
-
-    // parse the prompt template
-    std::vector<std::smatch> placeholders;
-    {
-        std::string err;
-        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
-            responseCallback(-1, err);
-            std::cerr << err << "\n";
-            return;
-        }
-    }
-
-    setTokenizeInputPosition(promptCtx.n_past);
-
-    // tokenize the user prompt
-    std::vector<Token> embd_inp;
-    if (placeholders.empty()) {
-        // this is unusual, but well-defined
-        std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptTemplate, true);
-    } else {
-        // template: beginning of user prompt
-        const auto &phUser = placeholders[0];
-        std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty())
-            embd_inp = tokenize(userPrefix, true);
-
-        // user input (shouldn't have special token processing)
-        auto tokens = tokenize(prompt, special);
-        embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-
-        // template: end of user prompt + start of assistant prompt
-        size_t start = phUser.position() + phUser.length();
-        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
-        auto userToAsst = promptTemplate.substr(start, end - start);
-        if (!userToAsst.empty()) {
-            tokens = tokenize(userToAsst, true);
-            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-        }
-    }
-
-    // decode the user prompt
-    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
-                      /*alwaysDecode*/ true))
-        return; // error
-
-    // decode the assistant's reply, either generated or spoofed
-    if (!fakeReply) {
-        generateResponse(responseCallback, allowContextShift, promptCtx);
-    } else {
-        embd_inp = tokenize(*fakeReply, false);
-        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
-            return; // error
-    }
-
-    // decode the rest of the prompt template
-    // template: end of assistant prompt
-    std::string asstSuffix;
-    if (placeholders.size() >= 2) {
-        size_t start = placeholders[1].position() + placeholders[1].length();
-        asstSuffix = promptTemplate.substr(start);
-    } else {
-        asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
-    }
-    if (!asstSuffix.empty()) {
-        embd_inp = tokenize(asstSuffix, true);
-        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
-    }
+    if (!isModelLoaded())
+        throw std::invalid_argument("Attempted to tokenize with an unloaded model.");
+    return int32_t(tokenize(prompt).size());
 }

-// returns false on error
-bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
-                           std::function<bool(int32_t, const std::string&)> responseCallback,
-                           bool allowContextShift,
-                           PromptContext &promptCtx,
-                           std::vector<Token> embd_inp,
-                           bool isResponse,
-                           bool alwaysDecode) {
-    if ((int) embd_inp.size() > contextLength() - 4) {
-        // FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
-        // translation
-        responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
-        std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
-            " tokens and the context window is " << contextLength() << "!\n";
-        return false;
-    }
+auto LLModel::decodePrompt(
+    const PromptCallback &promptCallback,
+    const PromptContext &promptCtx,
+    std::vector<Token> embd_inp
+) -> std::optional<int32_t>
+{
+    assert(!embd_inp.empty());

-    // FIXME(jared): There are mitigations for this situation, such as making room before
-    // copying the prompt context, or restoring the KV cache when we restore the prompt
-    // context.
-    if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
-        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
-                  << ", n_ctx=" << contextLength() << "\n";
-        return false;
-    }
-
-    // always decode something before generating, even if cached
-    if (alwaysDecode && embd_inp.empty()) {
-        auto cache = inputTokens();
-        if (!promptCtx.n_past)
-            throw std::runtime_error("zero token prompt is not supported");
-        assert(!cache.empty());
-        embd_inp.push_back(cache.back());
-        promptCtx.n_past--;
-    }
+    int32_t nCtx = contextLength();
+    int32_t n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);

     // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
     // requested n_past.
     // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
-    auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
-    size_t start_offset = embd_inp_start - embd_inp.begin();
+    int32_t nPast = computeModelInputPosition(embd_inp);

     // always decode up to a full batch before generating, even if cached
-    if (alwaysDecode)
-        start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
+    nPast -= std::min(n_batch, nPast);

-    setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
-
-    // execute the callback even for skipped tokens
-    size_t i = 0;
-    for (; i < start_offset; i++) {
-        Token tok = embd_inp[i];
-        bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
-        if (!res)
-            return false;
-    }
+    // TODO(jared): generalize this to find the smallest new_embd_inp.size() - nPast given the cache
+    if (!nPast && int32_t(embd_inp.size()) > nCtx) {
+        // no cache hit -> shift the input before even processing
+
+        int32_t nKeep = shouldAddBOS();
+        auto newLength = int32_t(nCtx * (1.f - promptCtx.contextErase));
+        int32_t nDiscard = int32_t(embd_inp.size()) - std::max(1, std::min(nCtx, newLength));
+
+        // execute the callback even for skipped tokens. this misrepresents the position of BOS but we don't care
+        auto discardedTokens = embd_inp | views::drop(nKeep) | views::take(nDiscard);
+        if (!promptCallback(discardedTokens, true))
+            return std::nullopt;
+
+        // erase nDiscard tokens
+        embd_inp.erase(discardedTokens.begin(), discardedTokens.end());
+        assert(int32_t(embd_inp.size()) <= nCtx);
+
+        // check the cache again, just in case
+        nPast = computeModelInputPosition(embd_inp);
+        nPast -= std::min(n_batch, nPast);
+    }
+
+    setModelInputPosition(nPast);
+
+    // execute the callback even for skipped tokens
+    if (!promptCallback(embd_inp | views::take(nPast), true))
+        return std::nullopt;

     // process the prompt in batches
-    while (i < embd_inp.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
-        std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+    for (int32_t i = nPast; i < embd_inp.size();) {
+        auto batch_end = std::min(i + n_batch, int32_t(embd_inp.size()));
+        std::span batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

         // Check if the context has run out...
-        if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
-            assert(allowContextShift);
-            shiftContext(promptCtx);
-            assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
+        if (nPast + int32_t(batch.size()) > nCtx) {
+            shiftContext(promptCtx, &nPast);
+            assert(nPast + int32_t(batch.size()) <= nCtx);
         }

-        if (!evalTokens(promptCtx, batch)) {
-            std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
-            return false;
-        }
+        // FIXME(Adam): We should find a way to bubble these strings to the UI level to allow for translation
+        if (!evalTokens(nPast, batch))
+            throw std::runtime_error("An internal error was encountered during prompt processing.");

-        size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t) {
-            Token tok = batch[t];
-            appendInputToken(promptCtx, tok);
-            bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
-            if (!res)
-                return false;
+        for (auto &tok : batch) {
+            appendInputToken(tok);
+            nPast++;
+            if (!promptCallback({ &tok, 1 }, false))
+                return std::nullopt;
         }
         i = batch_end;
     }

-    return true;
+    return nPast;
 }

 /*
@@ -251,22 +139,16 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st
     return std::string::npos;
 }

-void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
-                               bool allowContextShift,
-                               PromptContext &promptCtx) {
+void LLModel::generateResponse(
+    const ResponseCallback &responseCallback,
+    const PromptContext &promptCtx,
+    int32_t nPast
+) {
     static const char *stopSequences[] {
-        "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
+        "### System", "### Instruction", "### Human", "### User", "### Response", "### Assistant", "### Context",
+        "<|im_start|>", "<|im_end|>", "<|endoftext|>",
     };

-    // Don't even start if there is no room
-    if (!promptCtx.n_predict)
-        return;
-    if (!allowContextShift && promptCtx.n_past >= contextLength()) {
-        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
-                  << "\n";
-        return;
-    }
-
     initSampler(promptCtx);

     std::string cachedResponse;
@@ -281,25 +163,20 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
         cachedTokens.push_back(new_tok.value());
         cachedResponse += new_piece;

-        auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
+        auto accept = [this, &promptCtx, &new_tok, &nPast] {
             // Shift context if out of space
-            if (promptCtx.n_past >= contextLength()) {
-                (void)allowContextShift;
-                assert(allowContextShift);
-                shiftContext(promptCtx);
-                assert(promptCtx.n_past < contextLength());
+            if (nPast >= contextLength()) {
+                shiftContext(promptCtx, &nPast);
+                assert(nPast < contextLength());
             }

             // Accept the token
             Token tok = std::exchange(new_tok, std::nullopt).value();
-            if (!evalTokens(promptCtx, { &tok, 1 })) {
-                // TODO(jared): raise an exception
-                std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
-                return false;
-            }
+            if (!evalTokens(nPast, { &tok, 1 }))
+                throw std::runtime_error("An internal error was encountered during response generation.");

-            appendInputToken(promptCtx, tok);
-            return true;
+            appendInputToken(tok);
+            nPast++;
         };

         // Check for EOS
@@ -336,13 +213,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
             lengthLimit = cachedResponse.size() - new_piece.size();
         }

-        // Optionally stop if the context will run out
-        if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
-            std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
-                      << contextLength() << "\n";
-            stop = true;
-        }
-
         // Empty the cache, up to the length limit
         std::string::size_type responseLength = 0;
         while (!cachedTokens.empty()) {
@@ -359,8 +229,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
             cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());

             // Accept the token, if needed (not cached)
-            if (cachedTokens.empty() && new_tok && !accept())
-                return;
+            if (cachedTokens.empty() && new_tok)
+                accept();

             // Send the token
             if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
@@ -379,8 +249,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
         assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
         if (stop) {
             cachedTokens.pop_back();
-        } else if (!accept()) {
-            return;
+        } else {
+            accept();
         }
     }
 }
@@ -396,8 +266,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
     auto discard_start = inp.end() - cachedTokens.size();
     assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
 #endif
-
-    promptCtx.n_past -= cachedTokens.size();
 }

 void LLModel::embed(
gpt4all-bindings/python/docs/gpt4all_desktop/chat_templates.md (new file): 206 additions

@@ -0,0 +1,206 @@

## What are chat templates?
Natively, large language models only know how to complete plain text and do not know the difference between their input and their output. In order to support a chat with a person, LLMs are designed to use a template to convert the conversation to plain text using a specific format.

For a given model, it is important to use an appropriate chat template, as each model is designed to work best with a specific format. The chat templates included with the built-in models should be sufficient for most purposes.

There are two reasons you would want to alter the chat template:

- You are sideloading a model and there is no chat template available,
- You would like to have greater control over the input to the LLM than a system message provides.

|
## What is a system message?
|
||||||
|
A system message is a message that controls the responses from the LLM in a way that affects the entire conversation. System messages can be short, such as "Speak like a pirate.", or they can be long and contain a lot of context for the LLM to keep in mind.
|
||||||
|
|
||||||
|
Not all models are designed to use a system message, so they work with some models better than others.
|
||||||
|
|
||||||
|
|
||||||
|
## How do I customize the chat template or system message?
|
||||||
|
To customize the chat template or system message, go to Settings > Model. Make sure to select the correct model at the top. If you clone a model, you can use a different chat template or system message from the base model, enabling you to use different settings for each conversation.
|
||||||
|
|
||||||
|
These settings take effect immediately. After changing them, you can click "Redo last response" in the chat view, and the response will take the new settings into account.
|
||||||
|
|
||||||
|
|
||||||
|
## Do I need to write a chat template?
|
||||||
|
You typically do not need to write your own chat template. The exception is models that are not in the official model list and do not come with a chat template built-in. These will show a "Clear" option above the chat template field in the Model Settings page instead of a "Reset" option. See the section on [finding] or [creating] a chat template.
|
||||||
|
|
||||||
|
[finding]: #how-do-i-find-a-chat-template
|
||||||
|
[creating]: #advanced-how-do-chat-templates-work
|
||||||
|
|
||||||
|
|
||||||
|
## What changed in GPT4All v3.5?
|
||||||
|
GPT4All v3.5 overhauled the chat template system. There are three crucial differences:
|
||||||
|
|
||||||
|
- The chat template now formats an entire conversation instead of a single pair of messages,
|
||||||
|
- The chat template now uses Jinja syntax instead of `%1` and `%2` placeholders,
|
||||||
|
- And the system message should no longer contain control tokens or trailing whitespace.
|
||||||
|
|
||||||
|
If you are using any chat templates or system messages that had been added or altered from the default before upgrading to GPT4All v3.5 or newer, these will no longer work. See below for how to solve common errors you may see after upgrading.
|
||||||
|
|
||||||
|
|
||||||
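
As an illustration of the second point, the legacy Mistral Instruct template `[INST] %1 [/INST]` (where `%1` stood for the user message and, in templates that used it, `%2` for the model's reply) roughly corresponds to a Jinja fragment like the one below. This is a simplified sketch; the template that actually ships with GPT4All for Mistral also handles the system message and enforces that user and assistant turns alternate:

```jinja
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- ' [INST] ' + message['content'] + ' [/INST]' }}
    {%- elif message['role'] == 'assistant' %}
        {{- ' ' + message['content'] + eos_token }}
    {%- endif %}
{%- endfor %}
```
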
## Error/Warning: System message is not plain text.

This is easy to fix. Go to the model's settings and look at the system prompt. There are three things to look for:

- Control tokens such as `<|im_start|>`, `<|start_header_id|>`, or `<|system|>`
- A prefix such as `### System` or `SYSTEM:`
- Trailing whitespace, such as a space character or blank line.

If you see any of these things, remove them. For example, this legacy system prompt:

```
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
```

Should become this:

```
You are a helpful assistant.
```

If you do not see anything that needs to be changed, you can dismiss the error by making a minor modification to the message and then changing it back.

If you see a warning, your system message does not appear to be plain text. If you believe this warning is incorrect, it can be safely ignored. If in doubt, ask on the [Discord].

[Discord]: https://discord.gg/mGZE39AS3e


## Error: Legacy system prompt needs to be updated in Settings.

This is the same as [above][above-1], but appears on the chat page.

[above-1]: #errorwarning-system-message-is-not-plain-text


## Error/Warning: Chat template is not in Jinja format.

This is the result of attempting to use an old-style template (possibly from a previous version) in GPT4All v3.5+.

Go to the Model Settings page and select the affected model. If you see a "Reset" button, and you have not intentionally modified the prompt template, you can click "Reset". Otherwise, this is what you can do:

1. Back up your chat template by copying it safely to a text file and saving it. In the next step, it will be removed from GPT4All.
2. Click "Reset" or "Clear".
3. If you clicked "Clear", the chat template is now gone. Follow the steps to [find][finding] or [create][creating] a basic chat template for your model.
4. Customize the chat template to suit your needs. For help, read the section about [creating] a chat template.


## Error: Legacy prompt template needs to be updated in Settings.

This is the same as [above][above-2], but appears on the chat page.

[above-2]: #errorwarning-chat-template-is-not-in-jinja-format


## The chat template has a syntax error.

If there is a syntax error while editing the chat template, the details will be displayed in an error message above the input box. This could be because the chat template is not actually in Jinja format (see [above][above-2]).

Otherwise, you have either typed something incorrectly, or the model comes with a template that is incompatible with GPT4All. See [the below section][creating] on creating chat templates and make sure that everything is correct. When in doubt, ask on the [Discord].


## Error: No chat template configured.

This may appear for models that are not from the official model list and do not include a chat template. Older versions of GPT4All picked a poor default in this case. You will get much better results if you follow the steps to [find][finding] or [create][creating] a chat template for your model.


## Error: The chat template cannot be blank.

If the button above the chat template on the Model Settings page says "Clear", see [above][above-3]. If you see "Reset", click that button to restore a reasonable default. Also see the section on [syntax errors][chat-syntax-error].

[above-3]: #error-no-chat-template-configured
[chat-syntax-error]: #the-chat-template-has-a-syntax-error
## How do I find a chat template?

When in doubt, you can always ask the [Discord] community for help. Below are the instructions to find one on your own.

The authoritative source for a model's chat template is the HuggingFace repo that the original (non-GGUF) model came from. First, you should find this page. If you just have a model file, you can try a Google search for the model's name. If you know the page you downloaded the GGUF model from, its README usually links to the original non-GGUF model.

Once you have located the original model, there are two methods you can use to extract its chat template. Pick whichever one you are most comfortable with.

### Using the CLI (all models)

1. Install `jq` using your preferred package manager - e.g. Chocolatey (Windows), Homebrew (macOS), or apt (Ubuntu).
2. Download `tokenizer_config.json` from the model's "Files and versions" tab.
3. Open a command prompt in the directory to which you downloaded `tokenizer_config.json`.
4. Run `jq -r ".chat_template" tokenizer_config.json`. This shows the chat template in a human-readable form. You can copy this and paste it into the settings page.
5. (Optional) You can save the output to a text file like this: `jq -r ".chat_template" tokenizer_config.json >chat_template.txt`

If the output is "null", the model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (open models)

1. Install `transformers` using your preferred python package manager, e.g. `pip install transformers`. Make sure it is at least version v4.43.0.
2. Copy the ID of the HuggingFace model, using the clipboard icon next to the name. For example, if the URL is `https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B`, the ID is `NousResearch/Hermes-2-Pro-Llama-3-8B`.
3. Open a python interpreter (`python`) and run the following commands. Change the model ID in the example to the one you copied.

    ```
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-2-Pro-Llama-3-8B')
    >>> print(tokenizer.get_chat_template())
    ```

    You can copy the output and paste it into the settings page.

4. (Optional) You can save the output to a text file like this:

    ```
    >>> open('chat_template.txt', 'w').write(tokenizer.get_chat_template())
    ```

If you get a ValueError exception, this model does not provide a chat template. See the [below instructions][creating] on creating a chat template.
### Python (gated models)

Some models, such as Llama and Mistral, do not allow public access to their chat template. You must either use the CLI method above, or follow these instructions to use Python:

1. For these steps, you must have git and git-lfs installed.
2. You must have a HuggingFace account and be logged in.
3. You must already have access to the gated model. Otherwise, request access.
4. You must have an SSH key configured for git access to HuggingFace.
5. `git clone` the model's HuggingFace repo using the SSH clone URL. There is no need to download the entire model, which is very large. A good way to do this on Linux is:

    ```console
    $ GIT_LFS_SKIP_SMUDGE=1 git clone hf.co:meta-llama/Llama-3.1-8B-Instruct.git
    $ cd Llama-3.1-8B-Instruct
    $ git lfs pull -I "tokenizer.*"
    ```

6. Follow the above instructions for open models, but replace the model ID with the path to the directory containing `tokenizer_config.json`:

    ```
    >>> tokenizer = AutoTokenizer.from_pretrained('.')
    ```
## Advanced: How do chat templates work?

The chat template is applied to the entire conversation you see in the chat window. The template loops over the list of messages, each containing `role` and `content` fields. `role` is either `user`, `assistant`, or `system`.

GPT4All also supports the special variables `bos_token`, `eos_token`, and `add_generation_prompt`. See the [HuggingFace docs] for what those do.

[HuggingFace docs]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#special-variables
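
To make this concrete, here is a minimal ChatML-style template that uses only the fields described above. It is an illustrative sketch rather than a preset shipped with GPT4All; a real model's template may add a system block, BOS/EOS tokens, or extra checks:

```jinja
{%- for message in messages %}
    {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
```
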
## Advanced: How do I make a chat template?

The best way to create a chat template is to start by using an existing one as a reference. Then, modify it to use the format documented for the given model. Its README page may explicitly give an example of its template. Or, it may mention the name of a well-known standard template, such as ChatML, Alpaca, or Vicuna. GPT4All does not yet include presets for these templates, so they will have to be found in other models or taken from the community.

For more information, see the very helpful [HuggingFace guide]. Some of it is not applicable, such as the information about tool calling and RAG - GPT4All implements those features differently.

Some models use a prompt template that does not intuitively map to a multi-turn chat, because it is intended more for single instructions. The [FastChat] implementation of these templates is a useful reference for the correct way to extend them to multiple messages.

[HuggingFace guide]: https://huggingface.co/docs/transformers/v4.46.3/en/chat_templating#advanced-template-writing-tips
[FastChat]: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
## Advanced: What are GPT4All v1 templates?

GPT4All supports its own template syntax, which is nonstandard but provides complete control over the way LocalDocs sources and file attachments are inserted into the conversation. These templates begin with `{# gpt4all v1 #}` and look similar to the example below.

For standard templates, GPT4All combines the user message, sources, and attachments into the `content` field. For GPT4All v1 templates, this is not done, so they must be used directly in the template for those features to work correctly.

```jinja
{# gpt4all v1 #}
{%- for message in messages %}
    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
    {%- if message['role'] == 'user' %}
        {%- for source in message['sources'] %}
            {%- if loop.first %}
                {{- '### Context:\n' }}
            {%- endif %}
            {{- 'Collection: ' + source['collection'] + '\n' +
                'Path: ' + source['path'] + '\n' +
                'Excerpt: ' + source['text'] + '\n\n' }}
        {%- endfor %}
    {%- endif %}
    {%- for attachment in message['prompt_attachments'] %}
        {{- attachment['processed_content'] + '\n\n' }}
    {%- endfor %}
    {{- message['content'] | trim }}
    {{- '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
```
@ -9,7 +9,7 @@ import textwrap
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload

 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@ -23,7 +23,9 @@ else:
     from typing import TypedDict

 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
+    from typing_extensions import ParamSpec, TypeAlias
+    T = TypeVar("T")
+    P = ParamSpec("P")

 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')

@ -31,7 +33,7 @@ cuda_found: bool = False


 # TODO(jared): use operator.call after we drop python 3.10 support
-def _operator_call(obj, /, *args, **kwargs):
+def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
     return obj(*args, **kwargs)


@ -116,16 +118,15 @@ llmodel = load_llmodel_library()

 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [
-        ("n_past", ctypes.c_int32),
         ("n_predict", ctypes.c_int32),
         ("top_k", ctypes.c_int32),
         ("top_p", ctypes.c_float),
         ("min_p", ctypes.c_float),
         ("temp", ctypes.c_float),
         ("n_batch", ctypes.c_int32),
         ("repeat_penalty", ctypes.c_float),
         ("repeat_last_n", ctypes.c_int32),
         ("context_erase", ctypes.c_float),
     ]


@ -157,23 +158,21 @@ llmodel.llmodel_required_mem.restype = ctypes.c_size_t
 llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool

-PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
+PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
+SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)

 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
     ctypes.c_char_p,
-    ctypes.c_char_p,
     PromptCallback,
     ResponseCallback,
-    ctypes.c_bool,
     ctypes.POINTER(LLModelPromptContext),
-    ctypes.c_bool,
-    ctypes.c_char_p,
+    ctypes.POINTER(ctypes.c_char_p),
 ]

-llmodel.llmodel_prompt.restype = None
+llmodel.llmodel_prompt.restype = ctypes.c_bool

 llmodel.llmodel_embed.argtypes = [
     ctypes.c_void_p,
@ -222,6 +221,12 @@ llmodel.llmodel_model_backend_name.restype = ctypes.c_char_p
 llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p

+llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char_p)]
+llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
+
+llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
+llmodel.llmodel_model_foreach_special_token.restype = None
+
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
 EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@ -266,7 +271,6 @@ class LLModel:
         self.model_path = model_path.encode()
         self.n_ctx = n_ctx
         self.ngl = ngl
-        self.context: LLModelPromptContext | None = None
         self.buffer = bytearray()
         self.buff_expecting_cont_bytes: int = 0

@ -286,6 +290,10 @@ class LLModel:

             raise RuntimeError(f"Unable to instantiate model: {errmsg}")
         self.model: ctypes.c_void_p | None = model
+        self.special_tokens_map: dict[str, str] = {}
+        llmodel.llmodel_model_foreach_special_token(
+            self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
+        )

     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
@ -312,6 +320,19 @@ class LLModel:
         dev = llmodel.llmodel_model_gpu_device_name(self.model)
         return None if dev is None else dev.decode()

+    def count_prompt_tokens(self, prompt: str) -> int:
+        if self.model is None:
+            self._raise_closed()
+        err = ctypes.c_char_p()
+        n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err))
+        if n_tok < 0:
+            s = err.value
+            errmsg = 'null' if s is None else s.decode()
+            raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
+        return n_tok
+
+    llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+
     @staticmethod
     def list_gpus(mem_required: int = 0) -> list[str]:
         """
@ -375,48 +396,6 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)

-    def _set_context(
-        self,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        min_p: float = 0.0,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-    ):
-        if self.context is None:
-            context = LLModelPromptContext(
-                n_past=0,
-                n_predict=n_predict,
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                temp=temp,
-                n_batch=n_batch,
-                repeat_penalty=repeat_penalty,
-                repeat_last_n=repeat_last_n,
-                context_erase=context_erase,
-            )
-            self.context = context
-        else:
-            context = self.context
-            if reset_context:
-                self.context.n_past = 0
-
-        self.context.n_predict = n_predict
-        self.context.top_k = top_k
-        self.context.top_p = top_p
-        self.context.min_p = min_p
-        self.context.temp = temp
-        self.context.n_batch = n_batch
-        self.context.repeat_penalty = repeat_penalty
-        self.context.repeat_last_n = repeat_last_n
-        self.context.context_erase = context_erase
-
     @overload
     def generate_embeddings(
         self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
@ -486,20 +465,18 @@ class LLModel:

     def prompt_model(
         self,
-        prompt: str,
-        prompt_template: str,
-        callback: ResponseCallbackType,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        min_p: float = 0.0,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-        special: bool = False,
+        prompt         : str,
+        callback       : ResponseCallbackType,
+        n_predict      : int   = 4096,
+        top_k          : int   = 40,
+        top_p          : float = 0.9,
+        min_p          : float = 0.0,
+        temp           : float = 0.1,
+        n_batch        : int   = 8,
+        repeat_penalty : float = 1.2,
+        repeat_last_n  : int   = 10,
+        context_erase  : float = 0.75,
+        reset_context  : bool  = False,
     ):
         """
         Generate response from model from a prompt.
@ -522,34 +499,38 @@ class LLModel:
         self.buffer.clear()
         self.buff_expecting_cont_bytes = 0

-        self._set_context(
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            min_p=min_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase,
-            reset_context=reset_context,
+        context = LLModelPromptContext(
+            n_predict      = n_predict,
+            top_k          = top_k,
+            top_p          = top_p,
+            min_p          = min_p,
+            temp           = temp,
+            n_batch        = n_batch,
+            repeat_penalty = repeat_penalty,
+            repeat_last_n  = repeat_last_n,
+            context_erase  = context_erase,
         )

-        llmodel.llmodel_prompt(
+        error_msg: bytes | None = None
+        def error_callback(msg: bytes) -> None:
+            nonlocal error_msg
+            error_msg = msg
+
+        err = ctypes.c_char_p()
+        if not llmodel.llmodel_prompt(
             self.model,
             ctypes.c_char_p(prompt.encode()),
-            ctypes.c_char_p(prompt_template.encode()),
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._callback_decoder(callback)),
-            True,
-            self.context,
-            special,
-            ctypes.c_char_p(),
-        )
+            context,
+            ctypes.byref(err),
+        ):
+            s = err.value
+            raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}")

     def prompt_model_streaming(
-        self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
-    ) -> Iterable[str]:
+        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any,
+    ) -> Iterator[str]:
         if self.model is None:
             self._raise_closed()

@ -568,15 +549,15 @@ class LLModel:

             return _generator_callback

-        def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
-            self.prompt_model(prompt, prompt_template, callback, **kwargs)
+        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, callback, **kwargs)
             output_queue.put(Sentinel.TERMINATING_SYMBOL)

         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
+            args=(prompt, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()
@ -631,5 +612,5 @@ class LLModel:

     # Empty prompt callback
     @staticmethod
-    def _prompt_callback(token_id: int) -> bool:
+    def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool:
         return True
@ -4,37 +4,66 @@ Python only API for running all GPT4All models.
 from __future__ import annotations

 import hashlib
+import json
 import os
 import platform
 import re
 import sys
 import warnings
 from contextlib import contextmanager
+from datetime import datetime
 from pathlib import Path
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload

+import jinja2
 import requests
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError

 from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
-                         LLModel, ResponseCallbackType, empty_response_callback)
+                         LLModel, ResponseCallbackType, _operator_call, empty_response_callback)

 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias

-if sys.platform == 'darwin':
+if sys.platform == "darwin":
     import fcntl

 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"

-DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
+ConfigType: TypeAlias = "dict[str, Any]"

-ConfigType: TypeAlias = 'dict[str, Any]'
-MessageType: TypeAlias = 'dict[str, str]'
+# Environment setup adapted from HF transformers
+@_operator_call
+def _jinja_env() -> ImmutableSandboxedEnvironment:
+    def raise_exception(message: str) -> NoReturn:
+        raise jinja2.exceptions.TemplateError(message)
+
+    def tojson(obj: Any, indent: int | None = None) -> str:
+        return json.dumps(obj, ensure_ascii=False, indent=indent)
+
+    def strftime_now(fmt: str) -> str:
+        return datetime.now().strftime(fmt)
+
+    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+    env.filters["tojson"         ] = tojson
+    env.globals["raise_exception"] = raise_exception
+    env.globals["strftime_now"   ] = strftime_now
+    return env
+
+
+class MessageType(TypedDict):
+    role: str
+    content: str
+
+
+class ChatSession(NamedTuple):
+    template: jinja2.Template
+    history: list[MessageType]


 class Embed4All:
@ -54,7 +83,7 @@ class Embed4All:
             kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor.
         """
         if model_name is None:
-            model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
+            model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs)

     def __enter__(self) -> Self:
@ -145,18 +174,18 @@ class Embed4All:
             dimensionality = -1
         else:
             if dimensionality <= 0:
-                raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}')
+                raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}")
             if dimensionality < self.MIN_DIMENSIONALITY:
                 warnings.warn(
-                    f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.'
-                    ' Performance may be degraded.'
+                    f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}."
+                    " Performance may be degraded."
                 )
         try:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
         result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
-        return result if return_dict else result['embeddings']
+        return result if return_dict else result["embeddings"]


 class GPT4All:
@ -204,8 +233,7 @@ class GPT4All:
         """

         self.model_type = model_type
-        self._history: list[MessageType] | None = None
-        self._current_prompt_template: str = "{0}"
+        self._chat_session: ChatSession | None = None

         device_init = None
         if sys.platform == "darwin":
@ -264,7 +292,13 @@ class GPT4All:

     @property
     def current_chat_session(self) -> list[MessageType] | None:
-        return None if self._history is None else list(self._history)
+        return None if self._chat_session is None else self._chat_session.history
+
+    @current_chat_session.setter
+    def current_chat_session(self, history: list[MessageType]) -> None:
+        if self._chat_session is None:
+            raise ValueError("current_chat_session may only be set when there is an active chat session")
+        self._chat_session.history[:] = history

     @staticmethod
     def list_models() -> list[ConfigType]:
@ -276,7 +310,7 @@ class GPT4All:
         """
         resp = requests.get("https://gpt4all.io/models/models3.json")
         if resp.status_code != 200:
-            raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
+            raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}")
         return resp.json()

     @classmethod
@ -306,15 +340,9 @@ class GPT4All:
         # get the config for the model
         config: ConfigType = {}
         if allow_download:
-            available_models = cls.list_models()
-
-            for m in available_models:
-                if model_filename == m["filename"]:
-                    tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
-                    # change to Python-style formatting
-                    m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
-                    config.update(m)
-                    break
+            models = cls.list_models()
+            if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None:
+                config.update(model)

         # Validate download directory
         if model_path is None:
@ -378,13 +406,13 @@ class GPT4All:
         headers = {}
         if offset:
             print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
-            headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+            headers["Range"] = f"bytes={offset}-"  # resume incomplete response
             headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
         response = requests.get(url, stream=True, headers=headers)
         if response.status_code not in (200, 206):
-            raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
-        if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
-            raise ValueError('Connection was interrupted and server does not support range requests')
+            raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
+        if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")):
+            raise ValueError("Connection was interrupted and server does not support range requests")
         if (enc := response.headers.get("Content-Encoding")) is not None:
             raise ValueError(f"Expected identity Content-Encoding, got {enc}")
         return response
@ -483,19 +511,19 @@ class GPT4All:

     def generate(
         self,
-        prompt: str,
+        prompt         : str,
         *,
-        max_tokens: int = 200,
-        temp: float = 0.7,
-        top_k: int = 40,
-        top_p: float = 0.4,
-        min_p: float = 0.0,
-        repeat_penalty: float = 1.18,
-        repeat_last_n: int = 64,
-        n_batch: int = 8,
-        n_predict: int | None = None,
-        streaming: bool = False,
-        callback: ResponseCallbackType = empty_response_callback,
+        max_tokens     : int   = 200,
+        temp           : float = 0.7,
+        top_k          : int   = 40,
+        top_p          : float = 0.4,
+        min_p          : float = 0.0,
+        repeat_penalty : float = 1.18,
+        repeat_last_n  : int   = 64,
+        n_batch        : int   = 8,
+        n_predict      : int | None = None,
+        streaming      : bool  = False,
+        callback       : ResponseCallbackType = empty_response_callback,
     ) -> Any:
         """
         Generate outputs from any GPT4All model.
@ -520,122 +548,94 @@ class GPT4All:

         # Preparing the model request
         generate_kwargs: dict[str, Any] = dict(
-            temp=temp,
-            top_k=top_k,
-            top_p=top_p,
-            min_p=min_p,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            n_batch=n_batch,
-            n_predict=n_predict if n_predict is not None else max_tokens,
+            temp           = temp,
+            top_k          = top_k,
+            top_p          = top_p,
+            min_p          = min_p,
+            repeat_penalty = repeat_penalty,
+            repeat_last_n  = repeat_last_n,
+            n_batch        = n_batch,
+            n_predict      = n_predict if n_predict is not None else max_tokens,
         )

-        if self._history is not None:
-            # check if there is only one message, i.e. system prompt:
-            reset = len(self._history) == 1
-            self._history.append({"role": "user", "content": prompt})
-
-            fct_func = self._format_chat_prompt_template.__func__  # type: ignore[attr-defined]
-            if fct_func is GPT4All._format_chat_prompt_template:
-                if reset:
-                    # ingest system prompt
-                    # use "%1%2" and not "%1" to avoid implicit whitespace
-                    self.model.prompt_model(self._history[0]["content"], "%1%2",
-                                            empty_response_callback,
-                                            n_batch=n_batch, n_predict=0, reset_context=True, special=True)
-                prompt_template = self._current_prompt_template.format("%1", "%2")
-            else:
-                warnings.warn(
-                    "_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
-                    DeprecationWarning,
-                )
-                # special tokens won't be processed
-                prompt = self._format_chat_prompt_template(
-                    self._history[-1:],
-                    self._history[0]["content"] if reset else "",
-                )
-                prompt_template = "%1"
-                generate_kwargs["reset_context"] = reset
-        else:
-            prompt_template = "%1"
-            generate_kwargs["reset_context"] = True
-
         # Prepare the callback, process the model response
-        output_collector: list[MessageType]
-        output_collector = [
-            {"content": ""}
-        ]  # placeholder for the self._history if chat session is not activated
-
-        if self._history is not None:
-            self._history.append({"role": "assistant", "content": ""})
-            output_collector = self._history
-
-        def _callback_wrapper(
-            callback: ResponseCallbackType,
-            output_collector: list[MessageType],
-        ) -> ResponseCallbackType:
-            def _callback(token_id: int, response: str) -> bool:
-                nonlocal callback, output_collector
-
-                output_collector[-1]["content"] += response
-
-                return callback(token_id, response)
-
-            return _callback
+        full_response = ""
+
+        def _callback_wrapper(token_id: int, response: str) -> bool:
+            nonlocal full_response
+            full_response += response
+            return callback(token_id, response)
+
+        last_msg_rendered = prompt
+        if self._chat_session is not None:
+            session = self._chat_session
+            def render(messages: list[MessageType]) -> str:
+                return session.template.render(
+                    messages=messages,
+                    add_generation_prompt=True,
+                    **self.model.special_tokens_map,
+                )
+            session.history.append(MessageType(role="user", content=prompt))
+            prompt = render(session.history)
+            if len(session.history) > 1:
+                last_msg_rendered = render(session.history[-1:])
+
+        # Check request length
+        last_msg_len = self.model.count_prompt_tokens(last_msg_rendered)
+        if last_msg_len > (limit := self.model.n_ctx - 4):
+            raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).")

         # Send the request to the model
         if streaming:
-            return self.model.prompt_model_streaming(
-                prompt,
-                prompt_template,
-                _callback_wrapper(callback, output_collector),
-                **generate_kwargs,
-            )
+            def stream() -> Iterator[str]:
+                yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
+                if self._chat_session is not None:
+                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            return stream()

-        self.model.prompt_model(
-            prompt,
-            prompt_template,
-            _callback_wrapper(callback, output_collector),
-            **generate_kwargs,
-        )
-
-        return output_collector[-1]["content"]
+        self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
+        if self._chat_session is not None:
+            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+        return full_response

     @contextmanager
     def chat_session(
         self,
-        system_prompt: str | None = None,
-        prompt_template: str | None = None,
+        system_message: str | Literal[False] | None = None,
+        chat_template: str | None = None,
     ):
         """
         Context manager to hold an inference optimized chat session with a GPT4All model.

         Args:
-            system_prompt: An initial instruction for the model.
-            prompt_template: Template for the prompts with {0} being replaced by the user message.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
         """

-        if system_prompt is None:
-            system_prompt = self.config.get("systemPrompt", "")
+        if system_message is None:
+            system_message = self.config.get("systemMessage", False)

-        if prompt_template is None:
-            if (tmpl := self.config.get("promptTemplate")) is None:
-                warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
-                              "is deprecated. Defaulting to Alpaca.", DeprecationWarning)
-                tmpl = DEFAULT_PROMPT_TEMPLATE
-            prompt_template = tmpl
+        if chat_template is None:
+            if "name" not in self.config:
+                raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
+            if "chatTemplate" not in self.config:
+                raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
+                                          "currently implemented. Please pass a template to chat_session() directly.")
+            if (tmpl := self.config["chatTemplate"]) is None:
+                raise ValueError(f"The model {self.config['name']!r} does not support chat.")
+            chat_template = tmpl

-        if re.search(r"%1(?![0-9])", prompt_template):
-            raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
-                             "placeholder, please use '{0}' instead.")
-        self._history = [{"role": "system", "content": system_prompt}]
-        self._current_prompt_template = prompt_template
+        history = []
+        if system_message is not False:
+            history.append(MessageType(role="system", content=system_message))
+        self._chat_session = ChatSession(
+            template=_jinja_env.from_string(chat_template),
+            history=history,
+        )
         try:
             yield self
         finally:
-            self._history = None
-            self._current_prompt_template = "{0}"
+            self._chat_session = None

     @staticmethod
     def list_gpus() -> list[str]:
@ -647,43 +647,6 @@ class GPT4All:
         """
         return LLModel.list_gpus()

-    def _format_chat_prompt_template(
-        self,
-        messages: list[MessageType],
-        default_prompt_header: str = "",
-        default_prompt_footer: str = "",
-    ) -> str:
-        """
-        Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
-
-        Warning:
-            This function was deprecated in version 2.3.0, and will be removed in a future release.
-
-        Args:
-            messages: List of dictionaries. Each dictionary should have a "role" key
-                with value of "system", "assistant", or "user" and a "content" key with a
-                string value. Messages are organized such that "system" messages are at top of prompt,
-                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
-                "Response: {content}".
-
-        Returns:
-            Formatted prompt.
-        """
-
-        full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
-
-        for message in messages:
-            if message["role"] == "user":
-                user_message = self._current_prompt_template.format(message["content"])
-                full_prompt += user_message
-            if message["role"] == "assistant":
-                assistant_message = message["content"] + "\n"
-                full_prompt += assistant_message
-
-        full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
-
-        return full_prompt
-
-
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
@ -696,7 +659,7 @@ class _HasFileno(Protocol):


 def _fsync(fd: int | _HasFileno) -> None:
-    if sys.platform == 'darwin':
+    if sys.platform == "darwin":
         # Apple's fsync does not flush the drive write cache
         try:
             fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
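
Taken together, the bindings changes above replace the old `%1`/`%2` prompt templates with Jinja chat templates and a plain list of role/content messages. A minimal usage sketch of the updated Python API follows; it assumes the `gpt4all` package from this branch is installed, that the Llama 3.2 model file listed in `models3.json` has been downloaded, and `my_chat_template` stands in for a Jinja template string you supply yourself:

```python
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-3B-Instruct-Q4_0.gguf")

# With a catalog model, the built-in Jinja chat template is used automatically.
with model.chat_session(system_message="You are a helpful assistant."):
    print(model.generate("Name three primary colors.", max_tokens=64))
    # The running conversation is exposed as a list of {"role", "content"} dicts.
    print(model.current_chat_session)

# A sideloaded model needs an explicit Jinja template; pass system_message=False to omit one.
# with model.chat_session(chat_template=my_chat_template, system_message=False):
#     ...
```
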
@ -14,6 +14,7 @@ nav:
     - 'Models' : 'gpt4all_desktop/models.md'
     - 'LocalDocs' : 'gpt4all_desktop/localdocs.md'
     - 'Settings' : 'gpt4all_desktop/settings.md'
+    - 'Chat Templates' : 'gpt4all_desktop/chat_templates.md'
     - 'Cookbook':
       - 'Local AI Chat with Microsoft Excel': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-microsoft-excel.md'
       - 'Local AI Chat with your Google Drive': 'gpt4all_desktop/cookbook/use-local-ai-models-to-privately-chat-with-google-drive.md'
@ -88,9 +88,10 @@ setup(
     python_requires='>=3.8',
     packages=find_packages(),
     install_requires=[
+        'importlib_resources; python_version < "3.9"',
+        'jinja2~=3.1',
         'requests',
         'tqdm',
-        'importlib_resources; python_version < "3.9"',
         'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
     ],
     extras_require={
|
@ -190,6 +190,7 @@ qt_add_executable(chat
|
|||||||
src/database.cpp src/database.h
|
src/database.cpp src/database.h
|
||||||
src/download.cpp src/download.h
|
src/download.cpp src/download.h
|
||||||
src/embllm.cpp src/embllm.h
|
src/embllm.cpp src/embllm.h
|
||||||
|
src/jinja_helpers.cpp src/jinja_helpers.h
|
||||||
src/llm.cpp src/llm.h
|
src/llm.cpp src/llm.h
|
||||||
src/localdocs.cpp src/localdocs.h
|
src/localdocs.cpp src/localdocs.h
|
||||||
src/localdocsmodel.cpp src/localdocsmodel.h
|
src/localdocsmodel.cpp src/localdocsmodel.h
|
||||||
@ -215,6 +216,7 @@ qt_add_qml_module(chat
|
|||||||
qml/ApplicationSettings.qml
|
qml/ApplicationSettings.qml
|
||||||
qml/ChatDrawer.qml
|
qml/ChatDrawer.qml
|
||||||
qml/ChatItemView.qml
|
qml/ChatItemView.qml
|
||||||
|
qml/ChatMessageButton.qml
|
||||||
qml/ChatView.qml
|
qml/ChatView.qml
|
||||||
qml/CollectionsDrawer.qml
|
qml/CollectionsDrawer.qml
|
||||||
qml/HomeView.qml
|
qml/HomeView.qml
|
||||||
@ -227,7 +229,7 @@ qt_add_qml_module(chat
|
|||||||
qml/PopupDialog.qml
|
qml/PopupDialog.qml
|
||||||
qml/SettingsView.qml
|
qml/SettingsView.qml
|
||||||
qml/StartupDialog.qml
|
qml/StartupDialog.qml
|
||||||
qml/SwitchModelDialog.qml
|
qml/ConfirmationDialog.qml
|
||||||
qml/Theme.qml
|
qml/Theme.qml
|
||||||
qml/ThumbsDownDialog.qml
|
qml/ThumbsDownDialog.qml
|
||||||
qml/Toast.qml
|
qml/Toast.qml
|
||||||
@ -386,7 +388,7 @@ target_include_directories(chat PRIVATE deps/usearch/include
|
|||||||
target_link_libraries(chat
|
target_link_libraries(chat
|
||||||
PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg)
|
PRIVATE Qt6::Core Qt6::HttpServer Qt6::Pdf Qt6::Quick Qt6::Sql Qt6::Svg)
|
||||||
target_link_libraries(chat
|
target_link_libraries(chat
|
||||||
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx)
|
PRIVATE llmodel SingleApplication fmt::fmt duckx::duckx QXlsx jinja2cpp)
|
||||||
|
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
target_link_libraries(chat PRIVATE ${COCOA_LIBRARY})
|
target_link_libraries(chat PRIVATE ${COCOA_LIBRARY})
|
||||||
|
@ -11,3 +11,5 @@ add_subdirectory(DuckX)
|
|||||||
|
|
||||||
set(QT_VERSION_MAJOR 6)
|
set(QT_VERSION_MAJOR 6)
|
||||||
add_subdirectory(QXlsx/QXlsx)
|
add_subdirectory(QXlsx/QXlsx)
|
||||||
|
|
||||||
|
add_subdirectory(Jinja2Cpp)
|
||||||
|
1  gpt4all-chat/deps/Jinja2Cpp  (new submodule)
@ -0,0 +1 @@
+Subproject commit b2a716798bfa63c7dae303fc1e272964c4e1f9ee
@ -1,3 +1 @@
-<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
-<path d="M28.4138 9.17125L22.8288 3.585C22.643 3.39924 22.4225 3.25188 22.1799 3.15134C21.9372 3.0508 21.6771 2.99905 21.4144 2.99905C21.1517 2.99905 20.8916 3.0508 20.6489 3.15134C20.4062 3.25188 20.1857 3.39924 20 3.585L4.58626 19C4.39973 19.185 4.25185 19.4053 4.15121 19.648C4.05057 19.8907 3.99917 20.151 4.00001 20.4138V26C4.00001 26.5304 4.21072 27.0391 4.5858 27.4142C4.96087 27.7893 5.46958 28 6.00001 28H11.5863C11.849 28.0008 12.1093 27.9494 12.352 27.8488C12.5947 27.7482 12.815 27.6003 13 27.4138L28.4138 12C28.5995 11.8143 28.7469 11.5938 28.8474 11.3511C28.948 11.1084 28.9997 10.8483 28.9997 10.5856C28.9997 10.3229 28.948 10.0628 28.8474 9.82015C28.7469 9.57747 28.5995 9.35698 28.4138 9.17125ZM6.41376 20L17 9.41375L19.0863 11.5L8.50001 22.085L6.41376 20ZM6.00001 22.4138L9.58626 26H6.00001V22.4138ZM12 25.5863L9.91376 23.5L20.5 12.9138L22.5863 15L12 25.5863ZM24 13.5863L18.4138 8L21.4138 5L27 10.585L24 13.5863Z" fill="black"/>
-</svg>
+<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" fill="#000000" viewBox="0 0 256 256"><path d="M227.31,73.37,182.63,28.68a16,16,0,0,0-22.63,0L36.69,152A15.86,15.86,0,0,0,32,163.31V208a16,16,0,0,0,16,16H92.69A15.86,15.86,0,0,0,104,219.31L227.31,96a16,16,0,0,0,0-22.63ZM92.69,208H48V163.31l88-88L180.69,120ZM192,108.68,147.31,64l24-24L216,84.68Z"></path></svg>
 (Before: 1.0 KiB | After: 372 B)
@@ -29,7 +29,8 @@
 "description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
 "url": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf",
 "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
-"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>"
+"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
+"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
 },
 {
 "order": "c",
@@ -45,7 +46,8 @@
 "description": "<ul><li>Fast responses</li><li>Instruct model</li><li>Multilingual dialogue use</li><li>Agentic system capable</li><li>Trained by Meta</li><li>License: <a href=\"https://llama.meta.com/llama3_2/license/\">Meta Llama 3.2 Community License</a></li></ul>",
 "url": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf",
 "promptTemplate": "<|start_header_id|>user<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2",
-"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>"
+"systemPrompt": "<|start_header_id|>system<|end_header_id|>\nCutting Knowledge Date: December 2023\n\nYou are a helpful assistant.<|eot_id|>",
+"chatTemplate": "{{- bos_token }}\n{%- if not date_string is defined %}\n    {%- if strftime_now is defined %}\n        {%- set date_string = strftime_now(\"%d %b %Y\") %}\n    {%- else %}\n        {%- set date_string = \"26 Jul 2024\" %}\n    {%- endif %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content']|trim %}\n    {%- set messages = messages[1] %}\n{%- else %}\n    {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%- for message in messages %}\n    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"
 },
 {
 "order": "d",
@@ -77,7 +79,8 @@
 "systemPrompt": "",
 "description": "<strong>Strong overall fast instruction following model</strong><br><ul><li>Fast responses</li><li>Trained by Mistral AI<li>Uncensored</li><li>Licensed for commercial use</li></ul>",
 "url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf",
-"promptTemplate": "[INST] %1 [/INST]"
+"promptTemplate": "[INST] %1 [/INST]",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set loop_start = 1 %}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if (message['role'] == 'user') != ((loop.index0 - loop_start) % 2 == 0) %}\n            {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n        {%- endif %}\n        {%- if message['role'] == 'user' %}\n            {%- if loop.first and system_message is defined %}\n                {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n            {%- else %}\n                {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n            {%- endif %}\n        {%- elif message['role'] == 'assistant' %}\n            {{- ' ' + message['content'] + eos_token }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}"
 },
 {
 "order": "f",
@@ -125,7 +128,8 @@
 "systemPrompt": "",
 "description": "<strong>Very fast model with good quality</strong><br><ul><li>Fastest responses</li><li>Instruction based</li><li>Trained by TII<li>Finetuned by Nomic AI<li>Licensed for commercial use</ul>",
 "url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf",
-"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
+"promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if message['role'] == 'user' %}\n            {{- '### User: ' + message['content'] + '\\n\\n' }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- '### Assistant: ' + message['content'] + '\\n\\n' }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '### Assistant:' }}\n{%- endif %}"
 },
 {
 "order": "i",
@@ -140,7 +144,8 @@
 "type": "LLaMA2",
 "systemPrompt": "",
 "description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
-"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf",
+"chatTemplate": "{%- for message in messages %}\n    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
 },
 {
 "order": "j",
@@ -155,7 +160,8 @@
 "type": "LLaMA2",
 "systemPrompt": "",
 "description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
-"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf",
+"chatTemplate": "{%- for message in messages %}\n    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
 },
 {
 "order": "k",
@@ -170,7 +176,9 @@
 "type": "LLaMA2",
 "systemPrompt": "",
 "description": "<strong>Strong overall larger model</strong><br><ul><li>Instruction based<li>Gives very long responses<li>Finetuned with only 1k of high-quality data<li>Trained by Microsoft and Peking University<li>Cannot be used commercially</ul>",
-"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {{- messages[0]['content'] + ' ' }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in loop_messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if message['role'] == 'user' %}\n            {{- 'USER: ' + message['content'] }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- 'ASSISTANT: ' + message['content'] }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n        {%- if (loop.index0 - loop_start) % 2 == 0 %}\n            {{- ' ' }}\n        {%- else %}\n            {{- eos_token }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- 'ASSISTANT:' }}\n{%- endif %}",
+"systemMessage": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
 },
 {
 "order": "l",
@@ -186,7 +194,8 @@
 "description": "<strong>Ghost 7B v0.9.1</strong> fast, powerful and smooth for Vietnamese and English languages.",
 "url": "https://huggingface.co/lamhieu/ghost-7b-v0.9.1-gguf/resolve/main/ghost-7b-v0.9.1-Q4_0.gguf",
 "promptTemplate": "<|user|>\n%1</s>\n<|assistant|>\n%2</s>\n",
-"systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>"
+"systemPrompt": "<|system|>\nYou are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior.\n</s>",
+"systemMessage": "You are Ghost created by Lam Hieu. You are a helpful and knowledgeable assistant. You like to help and always give honest information, in its original language. In communication, you are always respectful, equal and promote positive behavior."
 },
 {
 "order": "m",
@@ -202,7 +211,8 @@
 "systemPrompt": "",
 "description": "<strong>Extremely good model</strong><br><ul><li>Instruction based<li>Gives long responses<li>Curated with 300,000 uncensored instructions<li>Trained by Nous Research<li>Cannot be used commercially</ul>",
 "url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf",
-"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
+"promptTemplate": "### Instruction:\n%1\n\n### Response:\n",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if message['role'] == 'user' %}\n            {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '### Instruction:\\n' }}\n{%- endif %}"
 },
 {
 "order": "n",
@@ -217,7 +227,9 @@
 "type": "LLaMA",
 "systemPrompt": "",
 "description": "<strong>Very good overall model</strong><br><ul><li>Instruction based<li>Based on the same dataset as Groovy<li>Slower than Groovy, with higher quality responses<li>Trained by Nomic AI<li>Cannot be used commercially</ul>",
-"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {{- messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if message['role'] == 'user' %}\n            {{- '### Instruction:\\n' + message['content'] + '\\n\\n' }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '### Response:\\n' }}\n{%- endif %}",
+"systemMessage": "Below is an instruction that describes a task. Write a response that appropriately completes the request."
 },
 {
 "order": "o",
@@ -234,7 +246,8 @@
 "description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
 "url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
 "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
-"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
+"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
+"chatTemplate": "{%- for message in messages %}\n    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
 },
 {
 "order": "p",
@@ -250,7 +263,8 @@
 "description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
 "url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf",
 "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
-"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
+"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n",
+"chatTemplate": "{%- for message in messages %}\n    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}"
 },
 {
 "order": "q",
@@ -282,7 +296,8 @@
 "description": "<strong>Small version of new model with novel dataset</strong><br><ul><li>Very fast responses</li><li>Instruction based</li><li>Explain tuned datasets</li><li>Orca Research Paper dataset construction approaches</li><li>Cannot be used commercially</li></ul>",
 "url": "https://gpt4all.io/models/gguf/orca-mini-3b-gguf2-q4_0.gguf",
 "promptTemplate": "### User:\n%1\n\n### Response:\n",
-"systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n"
+"systemPrompt": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n",
+"chatTemplate": "{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {{- '### System:\\n' + messages[0]['content'] + '\\n\\n' }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if message['role'] == 'user' %}\n            {{- '### User:\\n' + message['content'] + '\\n\\n' }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- '### Response:\\n' + message['content'] + '\\n\\n' }}\n        {%- else %}\n            {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '### Response:\\n' }}\n{%- endif %}"
 },
 {
 "order": "s",
@@ -299,7 +314,8 @@
 "systemPrompt": "",
 "promptTemplate": "%1",
 "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
-"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf",
+"chatTemplate": null
 },
 {
 "order": "t",
@@ -316,7 +332,8 @@
 "systemPrompt": "",
 "promptTemplate": "%1",
 "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
-"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf",
+"chatTemplate": null
 },
 {
 "order": "u",
@@ -333,7 +350,8 @@
 "systemPrompt": "",
 "promptTemplate": "%1",
 "description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
-"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf"
+"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf",
+"chatTemplate": null
 },
 {
 "order": "v",
@@ -351,7 +369,8 @@
 "embeddingModel": true,
 "systemPrompt": "",
 "description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
-"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
+"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf",
+"chatTemplate": null
 },
 {
 "order": "w",
@@ -367,7 +386,8 @@
 "type": "Bert",
 "embeddingModel": true,
 "description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
-"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf"
+"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2.gguf2.f16.gguf",
+"chatTemplate": null
 },
 {
 "order": "x",
@@ -383,7 +403,9 @@
 "description": "<strong>Mistral-based model for German-language applications</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by ellamind<li>Finetuned on German instruction and chat data</a><li>Licensed for commercial use</ul>",
 "url": "https://huggingface.co/TheBloke/em_german_mistral_v01-GGUF/resolve/main/em_german_mistral_v01.Q4_0.gguf",
 "promptTemplate": "USER: %1 ASSISTANT: ",
-"systemPrompt": "Du bist ein hilfreicher Assistent. "
+"systemPrompt": "Du bist ein hilfreicher Assistent. ",
+"chatTemplate": "{%- set system_message = false %}\n{%- if messages[0]['role'] == 'system' %}\n    {%- set loop_start = 1 %}\n    {%- set system_message = true %}\n    {{- messages[0]['content'] }}\n{%- else %}\n    {%- set loop_start = 0 %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if loop.index0 >= loop_start %}\n        {%- if (not loop.first) or (system_message is not none) %}\n            {{- ' ' }}\n        {%- endif %}\n        {%- if message['role'] == 'user' %}\n            {{- 'USER: ' + message['content'] }}\n        {%- elif message['role'] == 'assistant' %}\n            {{- 'ASSISTANT: ' + message['content'] }}\n        {%- else %}\n            {{- raise_exception('After the optional system message, conversation roles must be either user or assistant.') }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {%- if messages %}\n        {{- ' ' }}\n    {%- endif %}\n    {{- 'ASSISTANT:' }}\n{%- endif %}",
+"systemMessage": "Du bist ein hilfreicher Assistent."
 },
 {
 "order": "y",
@@ -400,7 +422,8 @@
 "embeddingModel": true,
 "systemPrompt": "",
 "description": "nomic-embed-text-v1",
-"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf"
+"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.f16.gguf",
+"chatTemplate": null
 },
 {
 "order": "z",
@@ -417,7 +440,8 @@
 "embeddingModel": true,
 "systemPrompt": "",
 "description": "nomic-embed-text-v1.5",
-"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf"
+"url": "https://gpt4all.io/models/gguf/nomic-embed-text-v1.5.f16.gguf",
+"chatTemplate": null
 },
 {
 "order": "zzz",
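Editor's note on the "chatTemplate" entries above: they are ordinary Jinja templates driven by a messages list and an add_generation_prompt flag. The sketch below is an illustration only (not code from this commit) and renders the ChatML-style template shown above using Python's jinja2; the two-message conversation is a hypothetical example.

# Illustrative only: render the ChatML-style "chatTemplate" with Python's jinja2.
# The message list is a made-up example, not data from this diff.
from jinja2 import Environment

CHATML_TEMPLATE = (
    "{%- for message in messages %}\n"
    "    {{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n' }}\n"
    "{%- endfor %}\n"
    "{%- if add_generation_prompt %}\n"
    "    {{- '<|im_start|>assistant\\n' }}\n"
    "{%- endif %}"
)

prompt = Environment().from_string(CHATML_TEMPLATE).render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    add_generation_prompt=True,
)
print(prompt)
# Produces "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"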
@@ -10,7 +10,7 @@ import network
 import llm

 MySettingsTab {
-onRestoreDefaultsClicked: {
+onRestoreDefaults: {
 MySettings.restoreApplicationDefaults();
 }
 title: qsTr("Application")
@@ -486,23 +486,6 @@ MySettingsTab {
 Accessible.name: nThreadsLabel.text
 Accessible.description: ToolTip.text
 }
-MySettingsLabel {
-id: saveChatsContextLabel
-text: qsTr("Save Chat Context")
-helpText: qsTr("Save the chat model's state to disk for faster loading. WARNING: Uses ~2GB per chat.")
-Layout.row: 12
-Layout.column: 0
-}
-MyCheckBox {
-id: saveChatsContextBox
-Layout.row: 12
-Layout.column: 2
-Layout.alignment: Qt.AlignRight
-checked: MySettings.saveChatsContext
-onClicked: {
-MySettings.saveChatsContext = !MySettings.saveChatsContext
-}
-}
 MySettingsLabel {
 id: trayLabel
 text: qsTr("Enable System Tray")
@@ -8,8 +8,23 @@ import QtQuick.Layouts
 import gpt4all
 import mysettings

+ColumnLayout {
+
+property var inputBoxText: null
+signal setInputBoxText(text: string)
+
+Item {
+
+Layout.fillWidth: true
+Layout.maximumWidth: parent.width
+Layout.preferredHeight: gridLayout.height
+
+HoverHandler { id: hoverArea }
+
 GridLayout {
-rows: 5
+id: gridLayout
+anchors.left: parent.left
+anchors.right: parent.right
 columns: 2

 Item {
@@ -40,7 +55,7 @@ GridLayout {
 to: 360
 duration: 1000
 loops: Animation.Infinite
-running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText)
+running: isCurrentResponse && currentChat.responseInProgress
 }
 }
 }
@@ -73,13 +88,11 @@ GridLayout {
 color: theme.mutedTextColor
 }
 RowLayout {
-visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText)
+visible: isCurrentResponse && (value === "" && currentChat.responseInProgress)
 Text {
 color: theme.mutedTextColor
 font.pixelSize: theme.fontSizeLarger
 text: {
-if (currentChat.restoringFromText)
-return qsTr("restoring from text ...");
 switch (currentChat.responseState) {
 case Chat.ResponseStopped: return qsTr("response stopped ...");
 case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@@ -99,10 +112,11 @@ GridLayout {
 Layout.row: 1
 Layout.column: 1
 Layout.fillWidth: true
-spacing: 20
+spacing: 10
 Flow {
 id: attachedUrlsFlow
 Layout.fillWidth: true
+Layout.bottomMargin: 10
 spacing: 10
 visible: promptAttachments.length !== 0
 Repeater {
@@ -156,7 +170,7 @@ GridLayout {
 focus: false
 readOnly: true
 font.pixelSize: theme.fontSizeLarge
-cursorVisible: currentResponse ? currentChat.responseInProgress : false
+cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false
 cursorPosition: text.length
 TapHandler {
 id: tapHandler
@@ -183,12 +197,12 @@ GridLayout {
 }

 onLinkActivated: function(link) {
-if (!currentResponse || !currentChat.responseInProgress)
+if (!isCurrentResponse || !currentChat.responseInProgress)
 Qt.openUrlExternally(link)
 }

 onLinkHovered: function (link) {
-if (!currentResponse || !currentChat.responseInProgress)
+if (!isCurrentResponse || !currentChat.responseInProgress)
 statusBar.externalHoveredLink = link
 }

@@ -239,13 +253,19 @@ GridLayout {
 textProcessor.setValue(value);
 }

+property bool textProcessorReady: false
+
 Component.onCompleted: {
 resetChatViewTextProcessor();
-chatModel.valueChanged.connect(function(i, value) {
+textProcessorReady = true;
-if (index === i)
+}
+
+Connections {
+target: chatModel
+function onValueChanged(i, value) {
+if (myTextArea.textProcessorReady && index === i)
 textProcessor.setValue(value);
 }
-);
 }

 Connections {
@@ -282,67 +302,6 @@ GridLayout {
 Network.sendConversation(currentChat.id, getConversationJson());
 }
 }

-Column {
-Layout.alignment: Qt.AlignRight
-Layout.rightMargin: 15
-visible: name === "Response: " &&
-(!currentResponse || !currentChat.responseInProgress) && MySettings.networkIsActive
-spacing: 10
-
-Item {
-width: childrenRect.width
-height: childrenRect.height
-MyToolButton {
-id: thumbsUp
-width: 24
-height: 24
-imageWidth: width
-imageHeight: height
-opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
-source: "qrc:/gpt4all/icons/thumbs_up.svg"
-Accessible.name: qsTr("Thumbs up")
-Accessible.description: qsTr("Gives a thumbs up to the response")
-onClicked: {
-if (thumbsUpState && !thumbsDownState)
-return
-
-chatModel.updateNewResponse(index, "")
-chatModel.updateThumbsUpState(index, true)
-chatModel.updateThumbsDownState(index, false)
-Network.sendConversation(currentChat.id, getConversationJson());
-}
-}
-
-MyToolButton {
-id: thumbsDown
-anchors.top: thumbsUp.top
-anchors.topMargin: 3
-anchors.left: thumbsUp.right
-anchors.leftMargin: 3
-width: 24
-height: 24
-imageWidth: width
-imageHeight: height
-checked: thumbsDownState
-opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
-transform: [
-Matrix4x4 {
-matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
-},
-Translate {
-x: thumbsDown.width
-}
-]
-source: "qrc:/gpt4all/icons/thumbs_down.svg"
-Accessible.name: qsTr("Thumbs down")
-Accessible.description: qsTr("Opens thumbs down dialog")
-onClicked: {
-thumbsDownDialog.open()
-}
-}
-}
-}
 }

 Item {
@@ -353,11 +312,13 @@ GridLayout {
 Layout.preferredWidth: childrenRect.width
 Layout.preferredHeight: childrenRect.height
 visible: {
+if (name !== "Response: ")
+return false
 if (consolidatedSources.length === 0)
 return false
 if (!MySettings.localDocsShowReferences)
 return false
-if (currentResponse && currentChat.responseInProgress
+if (isCurrentResponse && currentChat.responseInProgress
 && currentChat.responseState !== Chat.GeneratingQuestions )
 return false
 return true
@@ -443,7 +404,7 @@ GridLayout {
 return false
 if (!MySettings.localDocsShowReferences)
 return false
-if (currentResponse && currentChat.responseInProgress
+if (isCurrentResponse && currentChat.responseInProgress
 && currentChat.responseState !== Chat.GeneratingQuestions )
 return false
 return true
@@ -566,8 +527,139 @@ GridLayout {
 }
 }

+ConfirmationDialog {
+id: editPromptDialog
+dialogTitle: qsTr("Edit this prompt?")
+description: qsTr("The existing response and all later messages will be permanently erased.")
+onAccepted: {
+const msg = currentChat.popPrompt(index);
+if (msg !== null)
+setInputBoxText(msg);
+}
+}
+
+ConfirmationDialog {
+id: redoResponseDialog
+dialogTitle: qsTr("Redo this response?")
+description: qsTr("The existing response and all later messages will be permanently erased.")
+onAccepted: currentChat.regenerateResponse(index)
+}
+
+RowLayout {
+id: buttonRow
+Layout.row: 4
+Layout.column: 1
+Layout.maximumWidth: parent.width
+Layout.fillWidth: false
+Layout.alignment: Qt.AlignLeft | Qt.AlignTop
+spacing: 3
+visible: !isCurrentResponse || !currentChat.responseInProgress
+enabled: opacity > 0
+opacity: hoverArea.hovered
+readonly property var canModify: !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
+
+Behavior on opacity {
+OpacityAnimator { duration: 30 }
+}
+
+ChatMessageButton {
+visible: parent.canModify && model.name === "Prompt: "
+Layout.maximumWidth: 24
+Layout.maximumHeight: 24
+Layout.alignment: Qt.AlignVCenter
+Layout.fillWidth: false
+source: "qrc:/gpt4all/icons/edit.svg"
+onClicked: {
+if (inputBoxText === "")
+editPromptDialog.open();
+}
+name: qsTr("Edit")
+}
+
+ChatMessageButton {
+visible: parent.canModify && model.name === "Response: "
+Layout.maximumWidth: 24
+Layout.maximumHeight: 24
+Layout.alignment: Qt.AlignVCenter
+Layout.fillWidth: false
+name: qsTr("Redo")
+source: "qrc:/gpt4all/icons/regenerate.svg"
+onClicked: redoResponseDialog.open()
+}
+
+ChatMessageButton {
+Layout.maximumWidth: 24
+Layout.maximumHeight: 24
+Layout.alignment: Qt.AlignVCenter
+Layout.fillWidth: false
+name: qsTr("Copy")
+source: "qrc:/gpt4all/icons/copy.svg"
+onClicked: {
+myTextArea.selectAll();
+myTextArea.copy();
+myTextArea.deselect();
+}
+}
+
+Item {
+visible: name === "Response: " && MySettings.networkIsActive
+Layout.alignment: Qt.AlignVCenter
+Layout.preferredWidth: childrenRect.width
+Layout.preferredHeight: childrenRect.height
+Layout.fillWidth: false
+
+ChatMessageButton {
+id: thumbsUp
+anchors.left: parent.left
+anchors.verticalCenter: parent.verticalCenter
+opacity: thumbsUpState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
+source: "qrc:/gpt4all/icons/thumbs_up.svg"
+name: qsTr("Like response")
+onClicked: {
+if (thumbsUpState && !thumbsDownState)
+return
+
+chatModel.updateNewResponse(index, "")
+chatModel.updateThumbsUpState(index, true)
+chatModel.updateThumbsDownState(index, false)
+Network.sendConversation(currentChat.id, getConversationJson());
+}
+}
+
+ChatMessageButton {
+id: thumbsDown
+anchors.top: thumbsUp.top
+anchors.topMargin: buttonRow.spacing
+anchors.left: thumbsUp.right
+anchors.leftMargin: buttonRow.spacing
+checked: thumbsDownState
+opacity: thumbsDownState || thumbsUpState == thumbsDownState ? 1.0 : 0.2
+bgTransform: [
+Matrix4x4 {
+matrix: Qt.matrix4x4(-1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1)
+},
+Translate {
+x: thumbsDown.width
+}
+]
+source: "qrc:/gpt4all/icons/thumbs_down.svg"
+name: qsTr("Dislike response")
+onClicked: {
+thumbsDownDialog.open()
+}
+}
+}
+}
+} // GridLayout
+
+} // Item
+
+GridLayout {
+Layout.fillWidth: true
+Layout.maximumWidth: parent.width
+
 function shouldShowSuggestions() {
-if (!currentResponse)
+if (!isCurrentResponse)
 return false;
 if (MySettings.suggestionMode === 2) // Off
 return false;
@@ -577,8 +669,8 @@ GridLayout {
 }

 Item {
-visible: shouldShowSuggestions()
+visible: parent.shouldShowSuggestions()
-Layout.row: 4
+Layout.row: 5
 Layout.column: 0
 Layout.topMargin: 20
 Layout.alignment: Qt.AlignVCenter | Qt.AlignRight
@@ -601,8 +693,8 @@ GridLayout {
 }

 Item {
-visible: shouldShowSuggestions()
+visible: parent.shouldShowSuggestions()
-Layout.row: 4
+Layout.row: 5
 Layout.column: 1
 Layout.topMargin: 20
 Layout.fillWidth: true
@@ -627,8 +719,8 @@ GridLayout {
 }

 ColumnLayout {
-visible: shouldShowSuggestions()
+visible: parent.shouldShowSuggestions()
-Layout.row: 5
+Layout.row: 6
 Layout.column: 1
 Layout.fillWidth: true
 Layout.minimumHeight: 1
@@ -786,4 +878,7 @@ GridLayout {
 }
 }
 }
-}
+} // GridLayout
+
+} // ColumnLayout
gpt4all-chat/qml/ChatMessageButton.qml (new file, 20 lines)
@@ -0,0 +1,20 @@
+import QtQuick
+import QtQuick.Controls
+
+import gpt4all
+
+MyToolButton {
+property string name
+
+width: 24
+height: 24
+imageWidth: width
+imageHeight: height
+ToolTip {
+visible: parent.hovered
+y: parent.height * 1.5
+text: name
+delay: Qt.styleHints.mousePressAndHoldInterval
+}
+Accessible.name: name
+}
@@ -24,6 +24,12 @@ Rectangle {
 property var currentChat: ChatListModel.currentChat
 property var chatModel: currentChat.chatModel
+property var currentModelInfo: currentChat && currentChat.modelInfo
+property var currentModelId: null
+onCurrentModelInfoChanged: {
+const newId = currentModelInfo && currentModelInfo.id;
+if (currentModelId !== newId) { currentModelId = newId; }
+}
 signal addCollectionViewRequested()
 signal addModelViewRequested()

@@ -79,14 +85,11 @@ Rectangle {
 function open_(msg) { message = msg; open(); }
 }

-SwitchModelDialog {
+ConfirmationDialog {
 id: switchModelDialog
-anchors.centerIn: parent
+property int index: -1
-Item {
+dialogTitle: qsTr("Erase conversation?")
-Accessible.role: Accessible.Dialog
+description: qsTr("Changing the model will erase the current conversation.")
-Accessible.name: qsTr("Switch model dialog")
-Accessible.description: qsTr("Warn the user if they switch models, then context will be erased")
-}
 }

 PopupDialog {
@@ -103,6 +106,16 @@ Rectangle {
 font.pixelSize: theme.fontSizeLarge
 }

+ConfirmationDialog {
+id: resetContextDialog
+dialogTitle: qsTr("Erase conversation?")
+description: qsTr("The entire chat will be erased.")
+onAccepted: {
+Network.trackChatEvent("reset_context", { "length": chatModel.count });
+currentChat.reset();
+}
+}
+
 function getConversation() {
 var conversation = "";
 for (var i = 0; i < chatModel.count; i++) {
@@ -703,7 +716,7 @@ Rectangle {
 if (i !== -1) {
 defaultModel = comboBox.valueAt(i);
 } else {
-defaultModel = comboBox.valueAt(0);
+defaultModel = comboBox.count ? comboBox.valueAt(0) : "";
 }
 if (defaultModel !== "") {
 defaultModelName = ModelList.modelInfo(defaultModel).name;
@@ -790,9 +803,9 @@ Rectangle {
 Layout.leftMargin: 50
 Layout.rightMargin: 50
 Layout.alignment: Qt.AlignHCenter
-spacing: 25
+spacing: 10
 model: chatModel
-cacheBuffer: Math.max(0, listView.contentHeight)
+cacheBuffer: 2147483647

 ScrollBar.vertical: ScrollBar {
 policy: ScrollBar.AsNeeded
@@ -804,6 +817,12 @@ Rectangle {

 delegate: ChatItemView {
 width: listView.contentItem.width - 15
+inputBoxText: textInput.text
+onSetInputBoxText: text => {
+textInput.text = text;
+textInput.forceActiveFocus();
+textInput.cursorPosition = text.length;
+}
 }

 function scrollToEnd() {
@@ -832,11 +851,9 @@ Rectangle {
 clip: true
 z: 400

-property bool isHovered: {
+property bool isHovered: (
-return conversationTrayButton.isHovered ||
+conversationTrayButton.isHovered || resetContextButton.hovered || copyChatButton.hovered
-resetContextButton.hovered || copyChatButton.hovered ||
+)
-regenerateButton.hovered
-}

 state: conversationTrayContent.isHovered ? "expanded" : "collapsed"
 states: [
@@ -892,11 +909,7 @@ Rectangle {
 source: "qrc:/gpt4all/icons/recycle.svg"
 imageWidth: 20
 imageHeight: 20
-onClicked: {
+onClicked: resetContextDialog.open()
-Network.trackChatEvent("reset_context", { "length": chatModel.count })
-currentChat.reset();
-currentChat.processSystemPrompt();
-}
 ToolTip.visible: resetContextButton.hovered
 ToolTip.text: qsTr("Erase and reset chat session")
 }
@@ -921,34 +934,6 @@ Rectangle {
 ToolTip.visible: copyChatButton.hovered
 ToolTip.text: qsTr("Copy chat session to clipboard")
 }
-MyToolButton {
-id: regenerateButton
-Layout.preferredWidth: 40
-Layout.preferredHeight: 40
-source: "qrc:/gpt4all/icons/regenerate.svg"
-imageWidth: 20
-imageHeight: 20
-visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded && !currentChat.responseInProgress
-onClicked: {
-if (chatModel.count < 2)
-return
-var promptIndex = chatModel.count - 2
-var promptElement = chatModel.get(promptIndex)
-var responseIndex = chatModel.count - 1
-var responseElement = chatModel.get(responseIndex)
-if (promptElement.name !== "Prompt: " || responseElement.name !== "Response: ")
-return
-currentChat.regenerateResponse()
-chatModel.updateCurrentResponse(responseIndex, true)
-chatModel.updateStopped(responseIndex, false)
-chatModel.updateThumbsUpState(responseIndex, false)
-chatModel.updateThumbsDownState(responseIndex, false)
-chatModel.updateNewResponse(responseIndex, "")
-currentChat.prompt(promptElement.promptPlusAttachments)
-}
-ToolTip.visible: regenerateButton.hovered
-ToolTip.text: qsTr("Redo last chat response")
-}
 }
 }

@@ -1026,13 +1011,15 @@ Rectangle {
 anchors.leftMargin: 30
 horizontalAlignment: Qt.AlignRight
 verticalAlignment: Qt.AlignVCenter
-color: theme.mutedTextColor
+color: textInputView.error !== null ? theme.textErrorColor : theme.mutedTextColor
-visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== ""
+visible: currentChat.tokenSpeed !== "" || externalHoveredLink !== "" || textInputView.error !== null
 elide: Text.ElideRight
 wrapMode: Text.WordWrap
 text: {
 if (externalHoveredLink !== "")
 return externalHoveredLink
+if (textInputView.error !== null)
+return textInputView.error;
+
 const segments = [currentChat.tokenSpeed];
 const device = currentChat.device;
@@ -1050,6 +1037,7 @@ Rectangle {
 }
 font.pixelSize: theme.fontSizeSmaller
 font.bold: true
+onLinkActivated: function(link) { Qt.openUrlExternally(link) }
 }

 RectangularGlow {
@@ -1079,8 +1067,8 @@ Rectangle {
 Rectangle {
 id: textInputView
 color: theme.controlBackground
-border.width: 1
+border.width: error === null ? 1 : 2
-border.color: theme.controlBorder
+border.color: error === null ? theme.controlBorder : theme.textErrorColor
 radius: 10
 anchors.left: parent.left
 anchors.right: parent.right
@@ -1091,6 +1079,41 @@ Rectangle {
 height: textInputViewLayout.implicitHeight
 visible: !currentChat.isServer && ModelList.selectableModels.count !== 0

+property var error: null
+function checkError() {
+const info = currentModelInfo;
+if (info === null || !info.id) {
+error = null;
+} else if (info.chatTemplate.isLegacy) {
+error = qsTr("Legacy prompt template needs to be " +
+"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
+"</a> in Settings.");
+} else if (!info.chatTemplate.isSet) {
+error = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
+"chat template</a> configured.");
+} else if (/^\s*$/.test(info.chatTemplate.value)) {
+error = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
+"chat template</a> cannot be blank.");
+} else if (info.systemMessage.isLegacy) {
+error = qsTr("Legacy system prompt needs to be " +
+"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">updated" +
+"</a> in Settings.");
+} else
+error = null;
+}
+Component.onCompleted: checkError()
+Connections {
+target: window
+function onCurrentModelIdChanged() { textInputView.checkError(); }
+}
+Connections {
+target: MySettings
+function onChatTemplateChanged(info)
+{ if (info.id === window.currentModelId) textInputView.checkError(); }
+function onSystemMessageChanged(info)
+{ if (info.id === window.currentModelId) textInputView.checkError(); }
+}
+
 MouseArea {
 id: textInputViewMouseArea
 anchors.fill: parent
@@ -1214,16 +1237,16 @@ Rectangle {
 Accessible.role: Accessible.EditableText
 Accessible.name: placeholderText
 Accessible.description: qsTr("Send messages/prompts to the model")
-Keys.onReturnPressed: (event)=> {
+Keys.onReturnPressed: event => {
-if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier)
+if (event.modifiers & Qt.ControlModifier || event.modifiers & Qt.ShiftModifier) {
 event.accepted = false;
-else {
+} else if (!chatModel.hasError && textInputView.error === null) {
 editingFinished();
-sendMessage()
+sendMessage();
 }
 }
 function sendMessage() {
-if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress || currentChat.restoringFromText)
+if ((textInput.text === "" && attachmentModel.count === 0) || currentChat.responseInProgress)
 return

 currentChat.stopGenerating()
@@ -1338,6 +1361,7 @@ Rectangle {
 imageWidth: theme.fontSizeLargest
 imageHeight: theme.fontSizeLargest
 visible: !currentChat.responseInProgress && !currentChat.isServer && ModelList.selectableModels.count !== 0
+enabled: !chatModel.hasError && textInputView.error === null
 source: "qrc:/gpt4all/icons/send_message.svg"
 Accessible.name: qsTr("Send message")
 Accessible.description: qsTr("Sends the message/prompt contained in textfield to the model")
59 gpt4all-chat/qml/ConfirmationDialog.qml Normal file
@ -0,0 +1,59 @@
+import QtCore
+import QtQuick
+import QtQuick.Controls
+import QtQuick.Controls.Basic
+import QtQuick.Layouts
+
+MyDialog {
+id: confirmationDialog
+anchors.centerIn: parent
+modal: true
+padding: 20
+property alias dialogTitle: titleText.text
+property alias description: descriptionText.text
+
+Theme { id: theme }
+
+contentItem: ColumnLayout {
+Text {
+id: titleText
+Layout.alignment: Qt.AlignHCenter
+textFormat: Text.StyledText
+color: theme.textColor
+font.pixelSize: theme.fontSizeLarger
+font.bold: true
+}
+
+Text {
+id: descriptionText
+Layout.alignment: Qt.AlignHCenter
+textFormat: Text.StyledText
+color: theme.textColor
+font.pixelSize: theme.fontSizeMedium
+}
+}
+
+footer: DialogButtonBox {
+id: dialogBox
+padding: 20
+alignment: Qt.AlignRight
+spacing: 10
+MySettingsButton {
+text: qsTr("OK")
+textColor: theme.mediumButtonText
+backgroundColor: theme.mediumButtonBackground
+backgroundColorHovered: theme.mediumButtonBackgroundHovered
+DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
+}
+MySettingsButton {
+text: qsTr("Cancel")
+DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
+}
+background: Rectangle {
+color: "transparent"
+}
+Keys.onEnterPressed: confirmationDialog.accept()
+Keys.onReturnPressed: confirmationDialog.accept()
+}
+Component.onCompleted: dialogBox.forceActiveFocus()
+}
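For reference (sketch, not part of this diff): the new dialog is driven entirely through its two aliases and the standard accepted() signal. A minimal caller, modeled on how the settings pages later in this commit instantiate it — the id here is illustrative — looks like:

ConfirmationDialog {
    id: exampleConfirmDialog                              // illustrative id
    dialogTitle: qsTr("Restore defaults?")
    description: qsTr("This page of settings will be reset to the defaults.")
    onAccepted: MySettings.restoreLocalDocsDefaults()     // same call LocalDocsSettings uses below
}
// opened from a button handler, e.g.: exampleConfirmDialog.open()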
@ -10,7 +10,7 @@ import mysettings
import network

MySettingsTab {
-onRestoreDefaultsClicked: {
+onRestoreDefaults: {
MySettings.restoreLocalDocsDefaults();
}

@ -8,10 +8,34 @@ import mysettings
import chatlistmodel

MySettingsTab {
-onRestoreDefaultsClicked: {
+onRestoreDefaults: {
MySettings.restoreModelDefaults(root.currentModelInfo);
}
title: qsTr("Model")
+
+ConfirmationDialog {
+id: resetSystemMessageDialog
+property var index: null
+property bool resetClears: false
+dialogTitle: qsTr("%1 system message?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
+description: qsTr("The system message will be %1.").arg(resetClears ? qsTr("removed") : qsTr("reset to the default"))
+onAccepted: MySettings.resetModelSystemMessage(ModelList.modelInfo(index))
+function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
+}
+
+ConfirmationDialog {
+id: resetChatTemplateDialog
+property bool resetClears: false
+property var index: null
+dialogTitle: qsTr("%1 chat template?").arg(resetClears ? qsTr("Clear") : qsTr("Reset"))
+description: qsTr("The chat template will be %1.").arg(resetClears ? qsTr("erased") : qsTr("reset to the default"))
+onAccepted: {
+MySettings.resetModelChatTemplate(ModelList.modelInfo(index));
+templateTextArea.resetText();
+}
+function show(index_, resetClears_) { index = index_; resetClears = resetClears_; open(); }
+}
+
contentItem: GridLayout {
id: root
columns: 3
@ -35,6 +59,7 @@ MySettingsTab {

RowLayout {
Layout.fillWidth: true
+Layout.maximumWidth: parent.width
Layout.row: 2
Layout.column: 0
Layout.columnSpan: 2
@ -153,69 +178,154 @@ MySettingsTab {
Layout.fillWidth: true
}

-MySettingsLabel {
+RowLayout {
-visible: !root.currentModelInfo.isOnline
-text: qsTr("System Prompt")
-helpText: qsTr("Prefixed at the beginning of every conversation. Must contain the appropriate framing tokens.")
Layout.row: 7
-Layout.column: 0
+Layout.columnSpan: 2
Layout.topMargin: 15
+Layout.fillWidth: true
+Layout.maximumWidth: parent.width
+spacing: 10
+MySettingsLabel {
+id: systemMessageLabel
+text: qsTr("System Message")
+helpText: qsTr("A message to set the context or guide the behavior of the model. Leave blank for " +
+"none. NOTE: Since GPT4All 3.5, this should not contain control tokens.")
+onReset: () => resetSystemMessageDialog.show(root.currentModelId, resetClears)
+function updateResetButton() {
+const info = root.currentModelInfo;
+// NOTE: checks if the *override* is set, regardless of whether there is a default
+canReset = !!info.id && MySettings.isModelSystemMessageSet(info);
+resetClears = !info.defaultSystemMessage;
+}
+Component.onCompleted: updateResetButton()
+Connections {
+target: root
+function onCurrentModelIdChanged() { systemMessageLabel.updateResetButton(); }
+}
+Connections {
+target: MySettings
+function onSystemMessageChanged(info)
+{ if (info.id === root.currentModelId) systemMessageLabel.updateResetButton(); }
+}
+}
+Label {
+id: systemMessageLabelHelp
+visible: systemMessageArea.errState !== "ok"
+Layout.alignment: Qt.AlignBottom
+Layout.fillWidth: true
+Layout.rightMargin: 5
+Layout.maximumHeight: systemMessageLabel.height
+text: qsTr("System message is not " +
+"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">plain text</a>.")
+color: systemMessageArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
+font.pixelSize: theme.fontSizeLarger
+font.bold: true
+wrapMode: Text.Wrap
+elide: Text.ElideRight
+onLinkActivated: function(link) { Qt.openUrlExternally(link) }
+}
}

Rectangle {
-id: systemPrompt
+id: systemMessage
-visible: !root.currentModelInfo.isOnline
Layout.row: 8
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
color: "transparent"
-Layout.minimumHeight: Math.max(100, systemPromptArea.contentHeight + 20)
+Layout.minimumHeight: Math.max(100, systemMessageArea.contentHeight + 20)
MyTextArea {
-id: systemPromptArea
+id: systemMessageArea
anchors.fill: parent
-text: root.currentModelInfo.systemPrompt
+property bool isBeingReset: false
+function resetText() {
+const info = root.currentModelInfo;
+isBeingReset = true;
+text = (info.id ? info.systemMessage.value : null) ?? "";
+isBeingReset = false;
+}
+Component.onCompleted: resetText()
Connections {
target: MySettings
-function onSystemPromptChanged() {
+function onSystemMessageChanged(info)
-systemPromptArea.text = root.currentModelInfo.systemPrompt;
+{ if (info.id === root.currentModelId) systemMessageArea.resetText(); }
-}
}
Connections {
target: root
-function onCurrentModelInfoChanged() {
+function onCurrentModelIdChanged() { systemMessageArea.resetText(); }
-systemPromptArea.text = root.currentModelInfo.systemPrompt;
-}
}
+// strict validation, because setModelSystemMessage clears isLegacy
+readonly property var reLegacyCheck: (
+/(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>/m
+)
onTextChanged: {
-MySettings.setModelSystemPrompt(root.currentModelInfo, text)
+const info = root.currentModelInfo;
+if (!info.id) {
+errState = "ok";
+} else if (info.systemMessage.isLegacy && (isBeingReset || reLegacyCheck.test(text))) {
+errState = "error";
+} else
+errState = reLegacyCheck.test(text) ? "warning" : "ok";
+if (info.id && errState !== "error" && !isBeingReset)
+MySettings.setModelSystemMessage(info, text);
+systemMessageLabel.updateResetButton();
}
Accessible.role: Accessible.EditableText
+Accessible.name: systemMessageLabel.text
+Accessible.description: systemMessageLabelHelp.text
}
}

RowLayout {
Layout.row: 9
-Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15
+Layout.fillWidth: true
+Layout.maximumWidth: parent.width
spacing: 10
MySettingsLabel {
-id: promptTemplateLabel
+id: chatTemplateLabel
-text: qsTr("Prompt Template")
+text: qsTr("Chat Template")
-helpText: qsTr("The template that wraps every prompt.")
+helpText: qsTr("This Jinja template turns the chat into input for the model.")
+onReset: () => resetChatTemplateDialog.show(root.currentModelId, resetClears)
+function updateResetButton() {
+const info = root.currentModelInfo;
+canReset = !!info.id && (
+MySettings.isModelChatTemplateSet(info)
+|| templateTextArea.text !== (info.chatTemplate.value ?? "")
+);
+resetClears = !info.defaultChatTemplate;
+}
+Component.onCompleted: updateResetButton()
+Connections {
+target: root
+function onCurrentModelIdChanged() { chatTemplateLabel.updateResetButton(); }
+}
+Connections {
+target: MySettings
+function onChatTemplateChanged(info)
+{ if (info.id === root.currentModelId) chatTemplateLabel.updateResetButton(); }
+}
}
-MySettingsLabel {
+Label {
-id: promptTemplateLabelHelp
+id: chatTemplateLabelHelp
-text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.")
+visible: templateTextArea.errState !== "ok"
-color: theme.textErrorColor
+Layout.alignment: Qt.AlignBottom
-visible: templateTextArea.text.indexOf("%1") === -1
+Layout.fillWidth: true
-wrapMode: TextArea.Wrap
+Layout.rightMargin: 5
+Layout.maximumHeight: chatTemplateLabel.height
+text: templateTextArea.errMsg
+color: templateTextArea.errState === "error" ? theme.textErrorColor : theme.textWarningColor
+font.pixelSize: theme.fontSizeLarger
+font.bold: true
+wrapMode: Text.Wrap
+elide: Text.ElideRight
+onLinkActivated: function(link) { Qt.openUrlExternally(link) }
}
}

Rectangle {
-id: promptTemplate
+id: chatTemplate
Layout.row: 10
Layout.column: 0
Layout.columnSpan: 2
@ -226,27 +336,71 @@ MySettingsTab {
MyTextArea {
id: templateTextArea
anchors.fill: parent
-text: root.currentModelInfo.promptTemplate
+font: fixedFont
+property bool isBeingReset: false
+property var errMsg: null
+function resetText() {
+const info = root.currentModelInfo;
+isBeingReset = true;
+text = (info.id ? info.chatTemplate.value : null) ?? "";
+isBeingReset = false;
+}
+Component.onCompleted: resetText()
Connections {
target: MySettings
-function onPromptTemplateChanged() {
+function onChatTemplateChanged() { templateTextArea.resetText(); }
-templateTextArea.text = root.currentModelInfo.promptTemplate;
-}
}
Connections {
target: root
-function onCurrentModelInfoChanged() {
+function onCurrentModelIdChanged() { templateTextArea.resetText(); }
-templateTextArea.text = root.currentModelInfo.promptTemplate;
+}
-}
+function legacyCheck() {
-}
+return /%[12]\b/.test(text) || !/\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}/.test(text.replace(/\n/g, ''))
+|| !/\bcontent\b/.test(text);
}
onTextChanged: {
-if (templateTextArea.text.indexOf("%1") !== -1) {
+const info = root.currentModelInfo;
-MySettings.setModelPromptTemplate(root.currentModelInfo, text)
+let jinjaError;
+if (!info.id) {
+errMsg = null;
+errState = "ok";
+} else if (info.chatTemplate.isLegacy && (isBeingReset || legacyCheck())) {
+errMsg = null;
+errState = "error";
+} else if (text === "" && !info.chatTemplate.isSet) {
+errMsg = qsTr("No <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
+"chat template</a> configured.");
+errState = "error";
+} else if (/^\s*$/.test(text)) {
+errMsg = qsTr("The <a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
+"chat template</a> cannot be blank.");
+errState = "error";
+} else if ((jinjaError = MySettings.checkJinjaTemplateError(text)) !== null) {
+errMsg = qsTr("<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">Syntax" +
+" error</a>: %1").arg(jinjaError);
+errState = "error";
+} else if (legacyCheck()) {
+errMsg = qsTr("Chat template is not in " +
+"<a href=\"https://docs.gpt4all.io/gpt4all_desktop/chat_templates.html\">" +
+"Jinja format</a>.")
+errState = "warning";
+} else {
+errState = "ok";
+}
+if (info.id && errState !== "error" && !isBeingReset)
+MySettings.setModelChatTemplate(info, text);
+chatTemplateLabel.updateResetButton();
+}
+Keys.onPressed: event => {
+if (event.key === Qt.Key_Tab) {
+const a = templateTextArea;
+event.accepted = true; // suppress tab
+a.insert(a.cursorPosition, '    '); // four spaces
}
}
Accessible.role: Accessible.EditableText
-Accessible.name: promptTemplateLabel.text
+Accessible.name: chatTemplateLabel.text
-Accessible.description: promptTemplateLabelHelp.text
+Accessible.description: chatTemplateLabelHelp.text
}
}

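For reference (sketch, not part of this diff): a template of the following shape passes legacyCheck() above — it contains Jinja block and expression syntax plus the word "content", and no legacy "%1"/"%2" placeholders. The variables assumed here (messages, role, content) mirror what parsePrompt() in chatapi.cpp builds; any other template variables would be an assumption this diff does not document.

// Illustrative sketch only, written as a QML string constant for brevity.
readonly property string exampleChatTemplate:
    "{% for message in messages %}\n" +
    "{{ message['role'] }}: {{ message['content'] }}\n" +
    "{% endfor %}"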
|
@ -17,6 +17,7 @@ Button {
|
|||||||
property color borderColor: "transparent"
|
property color borderColor: "transparent"
|
||||||
property real fontPixelSize: theme.fontSizeLarge
|
property real fontPixelSize: theme.fontSizeLarge
|
||||||
property string toolTip
|
property string toolTip
|
||||||
|
property alias backgroundRadius: background.radius
|
||||||
|
|
||||||
contentItem: Text {
|
contentItem: Text {
|
||||||
text: myButton.text
|
text: myButton.text
|
||||||
@ -28,6 +29,7 @@ Button {
|
|||||||
Accessible.name: text
|
Accessible.name: text
|
||||||
}
|
}
|
||||||
background: Rectangle {
|
background: Rectangle {
|
||||||
|
id: background
|
||||||
radius: 10
|
radius: 10
|
||||||
border.width: borderWidth
|
border.width: borderWidth
|
||||||
border.color: borderColor
|
border.color: borderColor
|
||||||
|
@ -17,13 +17,42 @@ ColumnLayout {
property alias color: mainTextLabel.color
property alias linkColor: mainTextLabel.linkColor

-Label {
+property var onReset: null
-id: mainTextLabel
+property alias canReset: resetButton.enabled
-color: theme.settingsTitleTextColor
+property bool resetClears: false
-font.pixelSize: theme.fontSizeLarger
+
-font.bold: true
+Item {
-onLinkActivated: function(link) {
+anchors.margins: 5
-root.linkActivated(link);
+width: childrenRect.width
+height: mainTextLabel.contentHeight
+
+Label {
+id: mainTextLabel
+anchors.left: parent.left
+anchors.top: parent.top
+anchors.bottom: parent.bottom
+color: theme.settingsTitleTextColor
+font.pixelSize: theme.fontSizeLarger
+font.bold: true
+verticalAlignment: Text.AlignVCenter
+onLinkActivated: function(link) {
+root.linkActivated(link);
+}
+}
+
+MySettingsButton {
+id: resetButton
+anchors.baseline: mainTextLabel.baseline
+anchors.left: mainTextLabel.right
+height: mainTextLabel.contentHeight
+anchors.leftMargin: 10
+padding: 2
+leftPadding: 10
+rightPadding: 10
+backgroundRadius: 5
+text: resetClears ? qsTr("Clear") : qsTr("Reset")
+visible: root.onReset !== null
+onClicked: root.onReset()
+}
}
}
Label {
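For reference (sketch, not part of this diff): the reset button added to MySettingsLabel is opt-in — it only shows when a handler is assigned to onReset, while canReset and resetClears control its enabled state and caption. A minimal consumer, modeled on the pattern ModelSettings.qml uses above (the dialog id and the root properties come from that file; other pages would need equivalents):

MySettingsLabel {
    text: qsTr("System Message")
    resetClears: false                         // button reads "Reset" rather than "Clear"
    canReset: MySettings.isModelSystemMessageSet(root.currentModelInfo)
    onReset: () => resetSystemMessageDialog.show(root.currentModelId, resetClears)
}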
@ -9,7 +9,7 @@ Item {
property string title: ""
property Item contentItem: null
property bool showRestoreDefaultsButton: true
-signal restoreDefaultsClicked
+signal restoreDefaults

onContentItemChanged: function() {
if (contentItem) {
@ -19,6 +19,13 @@ Item {
}
}
+
+ConfirmationDialog {
+id: restoreDefaultsDialog
+dialogTitle: qsTr("Restore defaults?")
+description: qsTr("This page of settings will be reset to the defaults.")
+onAccepted: root.restoreDefaults()
+}
+
ScrollView {
id: scrollView
width: parent.width
@ -47,6 +54,7 @@ Item {
Column {
id: contentInner
Layout.fillWidth: true
+Layout.maximumWidth: parent.width
}

Item {
@ -63,9 +71,7 @@ Item {
Accessible.role: Accessible.Button
Accessible.name: text
Accessible.description: qsTr("Restores settings dialog to a default state")
-onClicked: {
+onClicked: restoreDefaultsDialog.open()
-root.restoreDefaultsClicked();
-}
}
}
}
@ -5,18 +5,27 @@ import QtQuick.Controls.Basic

TextArea {
id: myTextArea

+property string errState: "ok" // one of "ok", "error", "warning"
+
color: enabled ? theme.textColor : theme.mutedTextColor
placeholderTextColor: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarge
background: Rectangle {
implicitWidth: 150
color: theme.controlBackground
-border.width: 1
+border.width: errState === "ok" ? 1 : 2
-border.color: theme.controlBorder
+border.color: {
+switch (errState) {
+case "ok": return theme.controlBorder;
+case "warning": return theme.textWarningColor;
+case "error": return theme.textErrorColor;
+}
+}
radius: 10
}
padding: 10
wrapMode: TextArea.Wrap

ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval
}
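For reference (sketch, not part of this diff): MyTextArea's new errState only drives the border width and color; callers run their own validation and assign "ok", "warning", or "error", as the settings pages above do. A minimal sketch — the field id and the condition are illustrative:

MyTextArea {
    id: requiredField                          // illustrative
    placeholderText: qsTr("Required")
    onTextChanged: errState = (text === "") ? "error" : "ok"
}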
@ -16,6 +16,7 @@ Button {
property alias fillMode: image.fillMode
property alias imageWidth: image.sourceSize.width
property alias imageHeight: image.sourceSize.height
+property alias bgTransform: background.transform
contentItem: Text {
text: myButton.text
horizontalAlignment: Text.AlignHCenter
@ -26,6 +27,7 @@ Button {
}

background: Item {
+id: background
anchors.fill: parent
Rectangle {
anchors.fill: parent
@ -1,46 +0,0 @@
-import QtCore
-import QtQuick
-import QtQuick.Controls
-import QtQuick.Controls.Basic
-import QtQuick.Layouts
-import llm
-import mysettings
-
-MyDialog {
-id: switchModelDialog
-anchors.centerIn: parent
-modal: true
-padding: 20
-property int index: -1
-
-Theme {
-id: theme
-}
-
-contentItem: Text {
-textFormat: Text.StyledText
-text: qsTr("<b>Warning:</b> changing the model will erase the current conversation. Do you wish to continue?")
-color: theme.textColor
-font.pixelSize: theme.fontSizeLarge
-}
-
-footer: DialogButtonBox {
-id: dialogBox
-padding: 20
-alignment: Qt.AlignRight
-spacing: 10
-MySettingsButton {
-text: qsTr("Continue")
-Accessible.description: qsTr("Continue with model loading")
-DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
-}
-MySettingsButton {
-text: qsTr("Cancel")
-Accessible.description: qsTr("Cancel")
-DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
-}
-background: Rectangle {
-color: "transparent"
-}
-}
-}
@ -64,6 +64,9 @@ QtObject {
property color green800: Qt.hsla(123/360, 0.17, 0.24)
property color green900: Qt.hsla(124/360, 0.17, 0.20)
property color green950: Qt.hsla(125/360, 0.22, 0.10)
+property color green300_sat: Qt.hsla(122/360, 0.24, 0.73)
+property color green400_sat: Qt.hsla(122/360, 0.23, 0.58)
+property color green450_sat: Qt.hsla(122/360, 0.23, 0.52)

// yellow
property color yellow0: Qt.hsla(47/360, 0.90, 0.99)
@ -99,6 +102,7 @@ QtObject {
property color purple200: Qt.hsla(279/360, 1.0, 0.91)
property color purple300: Qt.hsla(279/360, 1.0, 0.84)
property color purple400: Qt.hsla(279/360, 1.0, 0.73)
+property color purple450: Qt.hsla(279/360, 1.0, 0.68)
property color purple500: Qt.hsla(279/360, 1.0, 0.63)
property color purple600: Qt.hsla(279/360, 1.0, 0.53)
property color purple700: Qt.hsla(279/360, 1.0, 0.47)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
property color mediumButtonBackground: {
|
||||||
|
switch (MySettings.chatTheme) {
|
||||||
|
case MySettingsEnums.ChatTheme.LegacyDark:
|
||||||
|
return purple400
|
||||||
|
case MySettingsEnums.ChatTheme.Dark:
|
||||||
|
return green400_sat
|
||||||
|
default:
|
||||||
|
return green400_sat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
property color mediumButtonBackgroundHovered: {
|
||||||
|
switch (MySettings.chatTheme) {
|
||||||
|
case MySettingsEnums.ChatTheme.LegacyDark:
|
||||||
|
return purple450
|
||||||
|
case MySettingsEnums.ChatTheme.Dark:
|
||||||
|
return green450_sat
|
||||||
|
default:
|
||||||
|
return green300_sat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
property color mediumButtonText: {
|
||||||
|
switch (MySettings.chatTheme) {
|
||||||
|
case MySettingsEnums.ChatTheme.LegacyDark:
|
||||||
|
return textColor
|
||||||
|
case MySettingsEnums.ChatTheme.Dark:
|
||||||
|
return textColor
|
||||||
|
default:
|
||||||
|
return white
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
property color darkButtonText: {
|
property color darkButtonText: {
|
||||||
switch (MySettings.chatTheme) {
|
switch (MySettings.chatTheme) {
|
||||||
case MySettingsEnums.ChatTheme.LegacyDark:
|
case MySettingsEnums.ChatTheme.LegacyDark:
|
||||||
@ -922,16 +959,8 @@ QtObject {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
property color textErrorColor: {
|
readonly property color textErrorColor: red400
|
||||||
switch (MySettings.chatTheme) {
|
readonly property color textWarningColor: yellow400
|
||||||
case MySettingsEnums.ChatTheme.LegacyDark:
|
|
||||||
return red400
|
|
||||||
case MySettingsEnums.ChatTheme.Dark:
|
|
||||||
return red400
|
|
||||||
default:
|
|
||||||
return red400
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
property color settingsTitleTextColor: {
|
property color settingsTitleTextColor: {
|
||||||
switch (MySettings.chatTheme) {
|
switch (MySettings.chatTheme) {
|
||||||
|
@ -1,7 +1,6 @@
#include "chat.h"

#include "chatlistmodel.h"
-#include "mysettings.h"
#include "network.h"
#include "server.h"

@ -11,7 +10,6 @@
#include <QLatin1String>
#include <QMap>
#include <QString>
-#include <QStringList>
#include <QVariant>
#include <Qt>
#include <QtLogging>
|
|||||||
// Should be in different threads
|
// Should be in different threads
|
||||||
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
|
||||||
|
connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
|
|
||||||
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
|
||||||
|
|
||||||
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
|
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
|
||||||
@ -75,11 +73,10 @@ void Chat::connectLLM()
|
|||||||
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
|
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
|
||||||
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
|
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
|
||||||
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
|
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
|
||||||
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection);
|
|
||||||
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
|
|
||||||
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
|
|
||||||
|
|
||||||
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
|
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
|
||||||
|
|
||||||
|
connect(ModelList::globalInstance(), &ModelList::modelInfoChanged, this, &Chat::handleModelInfoChanged);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::reset()
|
void Chat::reset()
|
||||||
@ -87,28 +84,17 @@ void Chat::reset()
|
|||||||
stopGenerating();
|
stopGenerating();
|
||||||
// Erase our current on disk representation as we're completely resetting the chat along with id
|
// Erase our current on disk representation as we're completely resetting the chat along with id
|
||||||
ChatListModel::globalInstance()->removeChatFile(this);
|
ChatListModel::globalInstance()->removeChatFile(this);
|
||||||
emit resetContextRequested();
|
|
||||||
m_id = Network::globalInstance()->generateUniqueId();
|
m_id = Network::globalInstance()->generateUniqueId();
|
||||||
emit idChanged(m_id);
|
emit idChanged(m_id);
|
||||||
// NOTE: We deliberately do no reset the name or creation date to indicate that this was originally
|
// NOTE: We deliberately do no reset the name or creation date to indicate that this was originally
|
||||||
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
|
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
|
||||||
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
|
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
|
||||||
// further down in the list. This might surprise the user. In the future, we might get rid of
|
// further down in the list. This might surprise the user. In the future, we might get rid of
|
||||||
// the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
|
// the "reset context" button in the UI.
|
||||||
// we effectively do a reset context. We *have* to do this right now when switching between different
|
|
||||||
// types of models. The only way to get rid of that would be a very long recalculate where we rebuild
|
|
||||||
// the context if we switch between different types of models. Probably the right way to fix this
|
|
||||||
// is to allow switching models but throwing up a dialog warning users if we switch between types
|
|
||||||
// of models that a long recalculation will ensue.
|
|
||||||
m_chatModel->clear();
|
m_chatModel->clear();
|
||||||
m_needsSave = true;
|
m_needsSave = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::processSystemPrompt()
|
|
||||||
{
|
|
||||||
emit processSystemPromptRequested();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Chat::resetResponseState()
|
void Chat::resetResponseState()
|
||||||
{
|
{
|
||||||
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
|
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
|
||||||
@ -160,25 +146,30 @@ void Chat::newPromptResponsePair(const QString &prompt, const QList<QUrl> &attac
|
|||||||
if (!attachedContexts.isEmpty())
|
if (!attachedContexts.isEmpty())
|
||||||
promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt;
|
promptPlusAttached = attachedContexts.join("\n\n") + "\n\n" + prompt;
|
||||||
|
|
||||||
newPromptResponsePairInternal(prompt, attachments);
|
resetResponseState();
|
||||||
emit resetResponseRequested();
|
qsizetype prevMsgIndex = m_chatModel->count() - 1;
|
||||||
|
if (prevMsgIndex >= 0)
|
||||||
|
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
|
||||||
|
m_chatModel->appendPrompt(prompt, attachments);
|
||||||
|
m_chatModel->appendResponse(prevMsgIndex + 1);
|
||||||
|
|
||||||
this->prompt(promptPlusAttached);
|
emit promptRequested(m_collections);
|
||||||
|
m_needsSave = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::prompt(const QString &prompt)
|
void Chat::regenerateResponse(int index)
|
||||||
{
|
{
|
||||||
resetResponseState();
|
resetResponseState();
|
||||||
emit promptRequested(m_collections, prompt);
|
emit regenerateResponseRequested(index);
|
||||||
m_needsSave = true;
|
m_needsSave = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::regenerateResponse()
|
QVariant Chat::popPrompt(int index)
|
||||||
{
|
{
|
||||||
const int index = m_chatModel->count() - 1;
|
auto content = m_llmodel->popPrompt(index);
|
||||||
m_chatModel->updateSources(index, QList<ResultInfo>());
|
|
||||||
emit regenerateResponseRequested();
|
|
||||||
m_needsSave = true;
|
m_needsSave = true;
|
||||||
|
if (content) return *content;
|
||||||
|
return QVariant::fromValue(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::stopGenerating()
|
void Chat::stopGenerating()
|
||||||
@ -202,6 +193,14 @@ void Chat::handleResponseChanged(const QString &response)
|
|||||||
m_chatModel->updateValue(index, response);
|
m_chatModel->updateValue(index, response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Chat::handleResponseFailed(const QString &error)
|
||||||
|
{
|
||||||
|
const int index = m_chatModel->count() - 1;
|
||||||
|
m_chatModel->updateValue(index, error);
|
||||||
|
m_chatModel->setError();
|
||||||
|
responseStopped(0);
|
||||||
|
}
|
||||||
|
|
||||||
void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
|
void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
|
||||||
{
|
{
|
||||||
if (m_shouldDeleteLater)
|
if (m_shouldDeleteLater)
|
||||||
@ -272,25 +271,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
|
|||||||
emit modelChangeRequested(modelInfo);
|
emit modelChangeRequested(modelInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
|
|
||||||
void Chat::serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments)
|
|
||||||
{
|
|
||||||
newPromptResponsePairInternal(prompt, attachments);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Chat::newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments)
|
|
||||||
{
|
|
||||||
resetResponseState();
|
|
||||||
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
|
|
||||||
m_chatModel->appendPrompt("Prompt: ", prompt, attachments);
|
|
||||||
m_chatModel->appendResponse("Response: ");
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Chat::restoringFromText() const
|
|
||||||
{
|
|
||||||
return m_llmodel->restoringFromText();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Chat::unloadAndDeleteLater()
|
void Chat::unloadAndDeleteLater()
|
||||||
{
|
{
|
||||||
if (!isModelLoaded()) {
|
if (!isModelLoaded()) {
|
||||||
@ -356,12 +336,6 @@ void Chat::generatedQuestionFinished(const QString &question)
|
|||||||
m_needsSave = true;
|
m_needsSave = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::handleRestoringFromText()
|
|
||||||
{
|
|
||||||
Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
|
|
||||||
emit restoringFromTextChanged();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Chat::handleModelLoadingError(const QString &error)
|
void Chat::handleModelLoadingError(const QString &error)
|
||||||
{
|
{
|
||||||
if (!error.isEmpty()) {
|
if (!error.isEmpty()) {
|
||||||
@ -396,12 +370,19 @@ QString Chat::fallbackReason() const
|
|||||||
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
|
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
|
||||||
{
|
{
|
||||||
m_databaseResults = results;
|
m_databaseResults = results;
|
||||||
const int index = m_chatModel->count() - 1;
|
|
||||||
m_chatModel->updateSources(index, m_databaseResults);
|
|
||||||
m_needsSave = true;
|
m_needsSave = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we need to notify listeners of the modelInfo property when its properties are updated,
|
||||||
|
// since it's a gadget and can't do that on its own
|
||||||
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
|
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
|
||||||
|
{
|
||||||
|
if (!m_modelInfo.id().isNull() && modelInfo.id() == m_modelInfo.id())
|
||||||
|
emit modelInfoChanged();
|
||||||
|
}
|
||||||
|
|
||||||
|
// react if a new model is loaded
|
||||||
|
void Chat::handleModelChanged(const ModelInfo &modelInfo)
|
||||||
{
|
{
|
||||||
if (m_modelInfo == modelInfo)
|
if (m_modelInfo == modelInfo)
|
||||||
return;
|
return;
|
||||||
@ -430,10 +411,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
|
|||||||
if (version >= 3)
|
if (version >= 3)
|
||||||
stream << m_collections;
|
stream << m_collections;
|
||||||
|
|
||||||
const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
|
if (!m_llmodel->serialize(stream, version))
|
||||||
if (version >= 6)
|
|
||||||
stream << serializeKV;
|
|
||||||
if (!m_llmodel->serialize(stream, version, serializeKV))
|
|
||||||
return false;
|
return false;
|
||||||
if (!m_chatModel->serialize(stream, version))
|
if (!m_chatModel->serialize(stream, version))
|
||||||
return false;
|
return false;
|
||||||
@ -462,19 +440,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
|
|||||||
if (!m_modelInfo.id().isEmpty())
|
if (!m_modelInfo.id().isEmpty())
|
||||||
emit modelInfoChanged();
|
emit modelInfoChanged();
|
||||||
|
|
||||||
bool discardKV = m_modelInfo.id().isEmpty();
|
|
||||||
|
|
||||||
if (version >= 3) {
|
if (version >= 3) {
|
||||||
stream >> m_collections;
|
stream >> m_collections;
|
||||||
emit collectionListChanged(m_collections);
|
emit collectionListChanged(m_collections);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool deserializeKV = true;
|
|
||||||
if (version >= 6)
|
|
||||||
stream >> deserializeKV;
|
|
||||||
|
|
||||||
m_llmodel->setModelInfo(m_modelInfo);
|
m_llmodel->setModelInfo(m_modelInfo);
|
||||||
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
|
if (!m_llmodel->deserialize(stream, version))
|
||||||
return false;
|
return false;
|
||||||
if (!m_chatModel->deserialize(stream, version))
|
if (!m_chatModel->deserialize(stream, version))
|
||||||
return false;
|
return false;
|
||||||
|
@ -12,6 +12,8 @@
#include <QObject>
#include <QQmlEngine>
#include <QString>
+#include <QStringList> // IWYU pragma: keep
+#include <QStringView>
#include <QtGlobal>

class QDataStream;
@ -27,7 +29,6 @@ class Chat : public QObject
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
-Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@ -77,13 +78,12 @@ public:
bool isNewChat() const { return m_name == tr("New Chat") && !m_chatModel->count(); }

Q_INVOKABLE void reset();
-Q_INVOKABLE void processSystemPrompt();
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void newPromptResponsePair(const QString &prompt, const QList<QUrl> &attachedUrls = {});
-Q_INVOKABLE void prompt(const QString &prompt);
+Q_INVOKABLE void regenerateResponse(int index);
-Q_INVOKABLE void regenerateResponse();
+Q_INVOKABLE QVariant popPrompt(int index);
Q_INVOKABLE void stopGenerating();

QList<ResultInfo> databaseResults() const { return m_databaseResults; }
@ -92,7 +92,6 @@ public:
ResponseState responseState() const;
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &modelInfo);
-bool restoringFromText() const;

Q_INVOKABLE void unloadModel();
Q_INVOKABLE void reloadModel();
@ -113,7 +112,6 @@ public:
Q_INVOKABLE bool hasCollection(const QString &collection) const;
Q_INVOKABLE void addCollection(const QString &collection);
Q_INVOKABLE void removeCollection(const QString &collection);
-void resetResponseState();

QString modelLoadingError() const { return m_modelLoadingError; }

@ -131,7 +129,7 @@ public:
void setNeedsSave(bool n) { m_needsSave = n; }

public Q_SLOTS:
-void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
+void resetResponseState();

Q_SIGNALS:
void idChanged(const QString &id);
@ -143,14 +141,12 @@ Q_SIGNALS:
void modelLoadingWarning(const QString &warning);
void responseInProgressChanged();
void responseStateChanged();
-void promptRequested(const QList<QString> &collectionList, const QString &prompt);
+void promptRequested(const QStringList &enabledCollections);
-void regenerateResponseRequested();
+void regenerateResponseRequested(int index);
void resetResponseRequested();
void resetContextRequested();
-void processSystemPromptRequested();
void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged();
-void restoringFromTextChanged();
void loadDefaultModelRequested();
void generateNameRequested();
void modelLoadingErrorChanged();
@ -166,22 +162,20 @@ Q_SIGNALS:

private Q_SLOTS:
void handleResponseChanged(const QString &response);
+void handleResponseFailed(const QString &error);
void handleModelLoadingPercentageChanged(float);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
-void handleRestoringFromText();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo);
+void handleModelChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value);

-private:
-void newPromptResponsePairInternal(const QString &prompt, const QList<PromptAttachment> &attachments);
-
private:
QString m_id;
QString m_name;
@ -1,10 +1,10 @@
#include "chatapi.h"

-#include <gpt4all-backend/llmodel.h>
+#include "utils.h"

#include <QCoreApplication>
-#include <QGuiApplication>
#include <QDebug>
+#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
@ -13,12 +13,17 @@
#include <QNetworkRequest>
#include <QThread>
#include <QUrl>
+#include <QUtf8StringView>
#include <QVariant>
+#include <QXmlStreamReader>
#include <Qt>
#include <QtGlobal>
#include <QtLogging>

+#include <expected>
+#include <functional>
#include <iostream>
+#include <utility>

using namespace Qt::Literals::StringLiterals;

@ -67,71 +72,119 @@ bool ChatAPI::isModelLoaded() const
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatAPI::prompt(const std::string &prompt,
|
static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
|
||||||
const std::string &promptTemplate,
|
{
|
||||||
std::function<bool(int32_t)> promptCallback,
|
QJsonArray messages;
|
||||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
|
||||||
bool allowContextShift,
|
|
||||||
PromptContext &promptCtx,
|
|
||||||
bool special,
|
|
||||||
std::optional<std::string_view> fakeReply) {
|
|
||||||
|
|
||||||
Q_UNUSED(promptCallback);
|
auto xmlError = [&xml] {
|
||||||
Q_UNUSED(allowContextShift);
|
return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
|
||||||
Q_UNUSED(special);
|
};
|
||||||
|
|
||||||
if (!isModelLoaded()) {
|
if (xml.hasError())
|
||||||
std::cerr << "ChatAPI ERROR: prompt won't work with an unloaded model!\n";
|
return xmlError();
|
||||||
return;
|
if (xml.atEnd())
|
||||||
|
return messages;
|
||||||
|
|
||||||
|
// skip header
|
||||||
|
bool foundElement = false;
|
||||||
|
do {
|
||||||
|
switch (xml.readNext()) {
|
||||||
|
using enum QXmlStreamReader::TokenType;
|
||||||
|
case Invalid:
|
||||||
|
return xmlError();
|
||||||
|
case EndDocument:
|
||||||
|
return messages;
|
||||||
|
default:
|
||||||
|
foundElement = true;
|
||||||
|
case StartDocument:
|
||||||
|
case Comment:
|
||||||
|
case DTD:
|
||||||
|
case ProcessingInstruction:
|
||||||
|
;
|
||||||
|
}
|
||||||
|
} while (!foundElement);
|
||||||
|
|
||||||
|
// document body loop
|
||||||
|
bool foundRoot = false;
|
||||||
|
for (;;) {
|
||||||
|
switch (xml.tokenType()) {
|
||||||
|
using enum QXmlStreamReader::TokenType;
|
||||||
|
case StartElement:
|
||||||
|
{
|
||||||
|
auto name = xml.name();
|
||||||
|
if (!foundRoot) {
|
||||||
|
if (name != "chat"_L1)
|
||||||
|
return std::unexpected(u"unexpected tag: %1"_s.arg(name));
|
||||||
|
foundRoot = true;
|
||||||
|
} else {
|
||||||
|
if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
|
||||||
|
return std::unexpected(u"unknown role: %1"_s.arg(name));
|
||||||
|
auto content = xml.readElementText();
|
||||||
|
if (xml.tokenType() != EndElement)
|
||||||
|
return xmlError();
|
||||||
|
messages << makeJsonObject({
|
||||||
|
{ "role"_L1, name.toString().trimmed() },
|
||||||
|
{ "content"_L1, content },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Characters:
|
||||||
|
if (!xml.isWhitespace())
|
||||||
|
return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
|
||||||
|
case Comment:
|
||||||
|
case ProcessingInstruction:
|
||||||
|
case EndElement:
|
||||||
|
break;
|
||||||
|
case EndDocument:
|
||||||
|
return messages;
|
||||||
|
case Invalid:
|
||||||
|
return xmlError();
|
||||||
|
default:
|
||||||
|
return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
|
||||||
|
}
|
||||||
|
xml.readNext();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!promptCtx.n_past) { m_queuedPrompts.clear(); }
|
void ChatAPI::prompt(
|
||||||
Q_ASSERT(promptCtx.n_past <= m_context.size());
|
std::string_view prompt,
|
||||||
m_context.resize(promptCtx.n_past);
|
const PromptCallback &promptCallback,
|
||||||
|
const ResponseCallback &responseCallback,
|
||||||
|
const PromptContext &promptCtx
|
||||||
|
) {
|
||||||
|
Q_UNUSED(promptCallback)
|
||||||
|
|
||||||
// FIXME(cebtenzzre): We're assuming people don't try to use %2 with ChatGPT. What would that even mean?
|
if (!isModelLoaded())
|
||||||
m_queuedPrompts << QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt));
|
throw std::invalid_argument("Attempted to prompt an unloaded model.");
|
||||||
|
if (!promptCtx.n_predict)
|
||||||
if (!promptCtx.n_predict && !fakeReply) {
|
return; // nothing requested
|
||||||
return; // response explicitly suppressed, queue prompt for later
|
|
||||||
}
|
|
||||||
|
|
||||||
QString formattedPrompt = m_queuedPrompts.join("");
|
|
||||||
m_queuedPrompts.clear();
|
|
||||||
|
|
||||||
if (fakeReply) {
|
|
||||||
promptCtx.n_past += 1;
|
|
||||||
m_context.append(formattedPrompt);
|
|
||||||
m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
|
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
|
||||||
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
|
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
|
||||||
// the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
|
// the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
|
||||||
// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
|
// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
|
||||||
// using the REST API to count tokens in a prompt.
|
// using the REST API to count tokens in a prompt.
|
||||||
QJsonObject root;
|
auto root = makeJsonObject({
|
||||||
root.insert("model", m_modelName);
|
{ "model"_L1, m_modelName },
|
||||||
root.insert("stream", true);
|
{ "stream"_L1, true },
|
||||||
root.insert("temperature", promptCtx.temp);
|
{ "temperature"_L1, promptCtx.temp },
|
||||||
root.insert("top_p", promptCtx.top_p);
|
{ "top_p"_L1, promptCtx.top_p },
|
||||||
|
});
|
||||||
|
|
||||||
// conversation history
|
// conversation history
|
||||||
QJsonArray messages;
|
{
|
||||||
for (int i = 0; i < m_context.count(); ++i) {
|
QUtf8StringView promptUtf8(prompt);
|
||||||
QJsonObject message;
|
QXmlStreamReader xml(promptUtf8);
|
||||||
message.insert("role", i % 2 == 0 ? "user" : "assistant");
|
auto messages = parsePrompt(xml);
|
||||||
message.insert("content", m_context.at(i));
|
if (!messages) {
|
||||||
messages.append(message);
|
auto error = fmt::format("Failed to parse API model prompt: {}", messages.error());
|
||||||
|
qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n';
|
||||||
|
throw std::invalid_argument(error);
|
||||||
|
}
|
||||||
|
root.insert("messages"_L1, *messages);
|
||||||
}
|
}
|
||||||
|
|
||||||
QJsonObject promptObject;
|
|
||||||
promptObject.insert("role", "user");
|
|
||||||
promptObject.insert("content", formattedPrompt);
|
|
||||||
messages.append(promptObject);
|
|
||||||
root.insert("messages", messages);
|
|
||||||
|
|
||||||
QJsonDocument doc(root);
|
QJsonDocument doc(root);
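For reference (illustrative, not captured output): with the sampling values taken from the PromptContext and the messages parsed above, doc.toJson(QJsonDocument::Compact) produces a request body along these lines; the model name here is a placeholder for whatever the user configured.

    {"model":"gpt-4o-mini","stream":true,"temperature":0.7,"top_p":0.4,
     "messages":[{"role":"system","content":"You are a helpful assistant."},
                 {"role":"user","content":"Hello!"}]}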
|
||||||
|
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
@ -148,12 +201,9 @@ void ChatAPI::prompt(const std::string &prompt,
|
|||||||
connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
|
connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
|
||||||
connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
|
connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
|
||||||
workerThread.start();
|
workerThread.start();
|
||||||
emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
|
emit request(m_apiKey, doc.toJson(QJsonDocument::Compact));
|
||||||
workerThread.wait();
|
workerThread.wait();
|
||||||
|
|
||||||
promptCtx.n_past += 1;
|
|
||||||
m_context.append(formattedPrompt);
|
|
||||||
m_context.append(worker.currentResponse());
|
|
||||||
m_responseCallback = nullptr;
|
m_responseCallback = nullptr;
|
||||||
|
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
@ -171,12 +221,8 @@ bool ChatAPI::callResponse(int32_t token, const std::string& string)
|
|||||||
return m_responseCallback(token, string);
|
return m_responseCallback(token, string);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatAPIWorker::request(const QString &apiKey,
|
void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array)
|
||||||
LLModel::PromptContext *promptCtx,
|
|
||||||
const QByteArray &array)
|
|
||||||
{
|
{
|
||||||
m_ctx = promptCtx;
|
|
||||||
|
|
||||||
QUrl apiUrl(m_chat->url());
|
QUrl apiUrl(m_chat->url());
|
||||||
const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
|
const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
|
||||||
QNetworkRequest request(apiUrl);
|
QNetworkRequest request(apiUrl);
|
||||||
@ -283,7 +329,6 @@ void ChatAPIWorker::handleReadyRead()
|
|||||||
const QJsonObject choice = choices.first().toObject();
|
const QJsonObject choice = choices.first().toObject();
|
||||||
const QJsonObject delta = choice.value("delta").toObject();
|
const QJsonObject delta = choice.value("delta").toObject();
|
||||||
const QString content = delta.value("content").toString();
|
const QString content = delta.value("content").toString();
|
||||||
Q_ASSERT(m_ctx);
|
|
||||||
m_currentResponse += content;
|
m_currentResponse += content;
|
||||||
if (!m_chat->callResponse(0, content.toStdString())) {
|
if (!m_chat->callResponse(0, content.toStdString())) {
|
||||||
reply->abort();
|
reply->abort();
|
||||||
|
@ -7,17 +7,14 @@
|
|||||||
#include <QNetworkReply>
|
#include <QNetworkReply>
|
||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
#include <QStringList>
|
|
||||||
#include <QList>
|
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <functional>
|
|
||||||
#include <optional>
|
|
||||||
#include <span>
|
#include <span>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
class QNetworkAccessManager;
|
class QNetworkAccessManager;
|
||||||
@ -28,16 +25,13 @@ class ChatAPIWorker : public QObject {
|
|||||||
public:
|
public:
|
||||||
ChatAPIWorker(ChatAPI *chatAPI)
|
ChatAPIWorker(ChatAPI *chatAPI)
|
||||||
: QObject(nullptr)
|
: QObject(nullptr)
|
||||||
, m_ctx(nullptr)
|
|
||||||
, m_networkManager(nullptr)
|
, m_networkManager(nullptr)
|
||||||
, m_chat(chatAPI) {}
|
, m_chat(chatAPI) {}
|
||||||
virtual ~ChatAPIWorker() {}
|
virtual ~ChatAPIWorker() {}
|
||||||
|
|
||||||
QString currentResponse() const { return m_currentResponse; }
|
QString currentResponse() const { return m_currentResponse; }
|
||||||
|
|
||||||
void request(const QString &apiKey,
|
void request(const QString &apiKey, const QByteArray &array);
|
||||||
LLModel::PromptContext *promptCtx,
|
|
||||||
const QByteArray &array);
|
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void finished();
|
void finished();
|
||||||
@ -49,7 +43,6 @@ private Q_SLOTS:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
ChatAPI *m_chat;
|
ChatAPI *m_chat;
|
||||||
LLModel::PromptContext *m_ctx;
|
|
||||||
QNetworkAccessManager *m_networkManager;
|
QNetworkAccessManager *m_networkManager;
|
||||||
QString m_currentResponse;
|
QString m_currentResponse;
|
||||||
};
|
};
|
||||||
@ -74,14 +67,14 @@ public:
|
|||||||
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
|
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
|
||||||
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
|
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
|
||||||
|
|
||||||
void prompt(const std::string &prompt,
|
void prompt(std::string_view prompt,
|
||||||
const std::string &promptTemplate,
|
const PromptCallback &promptCallback,
|
||||||
std::function<bool(int32_t)> promptCallback,
|
const ResponseCallback &responseCallback,
|
||||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
const PromptContext &ctx) override;
|
||||||
bool allowContextShift,
|
|
||||||
PromptContext &ctx,
|
[[noreturn]]
|
||||||
bool special,
|
int32_t countPromptTokens(std::string_view prompt) const override
|
||||||
std::optional<std::string_view> fakeReply) override;
|
{ Q_UNUSED(prompt); throwNotImplemented(); }
|
||||||
|
|
||||||
void setThreadCount(int32_t n_threads) override;
|
void setThreadCount(int32_t n_threads) override;
|
||||||
int32_t threadCount() const override;
|
int32_t threadCount() const override;
|
||||||
@ -91,19 +84,17 @@ public:
|
|||||||
void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
|
void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
|
||||||
QString url() const { return m_requestURL; }
|
QString url() const { return m_requestURL; }
|
||||||
|
|
||||||
QList<QString> context() const { return m_context; }
|
|
||||||
void setContext(const QList<QString> &context) { m_context = context; }
|
|
||||||
|
|
||||||
bool callResponse(int32_t token, const std::string &string);
|
bool callResponse(int32_t token, const std::string &string);
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
int32_t contextLength() const override
|
int32_t contextLength() const override
|
||||||
{ throwNotImplemented(); }
|
{ throwNotImplemented(); }
|
||||||
|
|
||||||
|
auto specialTokens() -> std::unordered_map<std::string, std::string> const override
|
||||||
|
{ return {}; }
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void request(const QString &apiKey,
|
void request(const QString &apiKey, const QByteArray &array);
|
||||||
LLModel::PromptContext *ctx,
|
|
||||||
const QByteArray &array);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
||||||
@ -114,8 +105,8 @@ protected:
|
|||||||
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
|
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
std::vector<Token> tokenize(std::string_view str, bool special) override
|
std::vector<Token> tokenize(std::string_view str) const override
|
||||||
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }
|
{ Q_UNUSED(str); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
bool isSpecialToken(Token id) const override
|
bool isSpecialToken(Token id) const override
|
||||||
@ -126,7 +117,7 @@ protected:
|
|||||||
{ Q_UNUSED(id); throwNotImplemented(); }
|
{ Q_UNUSED(id); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
void initSampler(PromptContext &ctx) override
|
void initSampler(const PromptContext &ctx) override
|
||||||
{ Q_UNUSED(ctx); throwNotImplemented(); }
|
{ Q_UNUSED(ctx); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
@ -134,33 +125,28 @@ protected:
|
|||||||
{ throwNotImplemented(); }
|
{ throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
|
bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override
|
||||||
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }
|
{ Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
void shiftContext(PromptContext &promptCtx) override
|
void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override
|
||||||
{ Q_UNUSED(promptCtx); throwNotImplemented(); }
|
{ Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
int32_t inputLength() const override
|
int32_t inputLength() const override
|
||||||
{ throwNotImplemented(); }
|
{ throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
void setTokenizeInputPosition(int32_t pos) override
|
int32_t computeModelInputPosition(std::span<const Token> input) const override
|
||||||
|
{ Q_UNUSED(input); throwNotImplemented(); }
|
||||||
|
|
||||||
|
[[noreturn]]
|
||||||
|
void setModelInputPosition(int32_t pos) override
|
||||||
{ Q_UNUSED(pos); throwNotImplemented(); }
|
{ Q_UNUSED(pos); throwNotImplemented(); }
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
|
void appendInputToken(Token tok) override
|
||||||
-> std::vector<Token>::const_iterator override
|
{ Q_UNUSED(tok); throwNotImplemented(); }
|
||||||
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
|
|
||||||
|
|
||||||
[[noreturn]]
|
|
||||||
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
|
|
||||||
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
|
|
||||||
|
|
||||||
[[noreturn]]
|
|
||||||
void appendInputToken(PromptContext &ctx, Token tok) override
|
|
||||||
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
|
|
||||||
|
|
||||||
[[noreturn]]
|
[[noreturn]]
|
||||||
const std::vector<Token> &endTokens() const override
|
const std::vector<Token> &endTokens() const override
|
||||||
@ -175,12 +161,10 @@ protected:
|
|||||||
{ throwNotImplemented(); }
|
{ throwNotImplemented(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<bool(int32_t, const std::string&)> m_responseCallback;
|
ResponseCallback m_responseCallback;
|
||||||
QString m_modelName;
|
QString m_modelName;
|
||||||
QString m_apiKey;
|
QString m_apiKey;
|
||||||
QString m_requestURL;
|
QString m_requestURL;
|
||||||
QList<QString> m_context;
|
|
||||||
QStringList m_queuedPrompts;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // CHATAPI_H
|
#endif // CHATAPI_H
|
||||||
|
@ -17,9 +17,10 @@
|
|||||||
#include <Qt>
|
#include <Qt>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#define CHAT_FORMAT_MAGIC 0xF5D553CC
|
static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC;
|
||||||
#define CHAT_FORMAT_VERSION 10
|
static constexpr qint32 CHAT_FORMAT_VERSION = 11;
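The bumped version (11) and the switch to typed constants change only how the header is written. A hedged sketch of the matching read side, with the QDataStream/file setup assumed rather than taken from this diff:

    QDataStream in(&file);      // `file` is an already-open chat file (assumption)
    quint32 magic;
    qint32  version;
    in >> magic >> version;
    if (magic != CHAT_FORMAT_MAGIC || version > CHAT_FORMAT_VERSION)
        return;                 // unrecognized or too-new format
    in.setVersion(QDataStream::Qt_6_2);
    // ...then Chat::deserialize(in, version) consumes the rest, as in ChatsRestoreThread::run()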
|
||||||
|
|
||||||
class MyChatListModel: public ChatListModel { };
|
class MyChatListModel: public ChatListModel { };
|
||||||
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
|
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
|
||||||
@ -118,8 +119,8 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
|
|||||||
}
|
}
|
||||||
QDataStream out(&tempFile);
|
QDataStream out(&tempFile);
|
||||||
|
|
||||||
out << (quint32)CHAT_FORMAT_MAGIC;
|
out << CHAT_FORMAT_MAGIC;
|
||||||
out << (qint32)CHAT_FORMAT_VERSION;
|
out << CHAT_FORMAT_VERSION;
|
||||||
out.setVersion(QDataStream::Qt_6_2);
|
out.setVersion(QDataStream::Qt_6_2);
|
||||||
|
|
||||||
qDebug() << "serializing chat" << fileName;
|
qDebug() << "serializing chat" << fileName;
|
||||||
@ -257,12 +258,15 @@ void ChatsRestoreThread::run()
|
|||||||
|
|
||||||
qDebug() << "deserializing chat" << f.file;
|
qDebug() << "deserializing chat" << f.file;
|
||||||
|
|
||||||
Chat *chat = new Chat;
|
auto chat = std::make_unique<Chat>();
|
||||||
chat->moveToThread(qGuiApp->thread());
|
chat->moveToThread(qGuiApp->thread());
|
||||||
if (!chat->deserialize(in, version)) {
|
bool ok = chat->deserialize(in, version);
|
||||||
|
if (!ok) {
|
||||||
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
|
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
|
||||||
|
} else if (!in.atEnd()) {
|
||||||
|
qWarning().nospace() << "error loading chat from " << file.fileName() << ": extra data at end of file";
|
||||||
} else {
|
} else {
|
||||||
emit chatRestored(chat);
|
emit chatRestored(chat.release());
|
||||||
}
|
}
|
||||||
if (f.oldFile)
|
if (f.oldFile)
|
||||||
file.remove(); // No longer storing in this directory
|
file.remove(); // No longer storing in this directory
|
File diff suppressed because it is too large
@ -13,20 +13,24 @@
|
|||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QPointer>
|
#include <QPointer>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
|
#include <QStringList> // IWYU pragma: keep
|
||||||
|
#include <QStringView>
|
||||||
#include <QThread>
|
#include <QThread>
|
||||||
#include <QVariantMap>
|
#include <QVariantMap> // IWYU pragma: keep
|
||||||
#include <QtGlobal>
|
#include <QtGlobal>
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <span>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <variant>
|
||||||
|
|
||||||
using namespace Qt::Literals::StringLiterals;
|
using namespace Qt::Literals::StringLiterals;
|
||||||
|
|
||||||
class QDataStream;
|
class QDataStream;
|
||||||
|
struct ChatItem;
|
||||||
|
|
||||||
// NOTE: values serialized to disk, do not change or reuse
|
// NOTE: values serialized to disk, do not change or reuse
|
||||||
enum class LLModelTypeV0 { // chat versions 2-5
|
enum class LLModelTypeV0 { // chat versions 2-5
|
||||||
@ -142,7 +146,6 @@ class Chat;
|
|||||||
class ChatLLM : public QObject
|
class ChatLLM : public QObject
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
|
|
||||||
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
|
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
|
||||||
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
|
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
|
||||||
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
|
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
|
||||||
@ -150,12 +153,14 @@ public:
|
|||||||
ChatLLM(Chat *parent, bool isServer = false);
|
ChatLLM(Chat *parent, bool isServer = false);
|
||||||
virtual ~ChatLLM();
|
virtual ~ChatLLM();
|
||||||
|
|
||||||
void destroy();
|
|
||||||
static void destroyStore();
|
static void destroyStore();
|
||||||
|
static std::optional<std::string> checkJinjaTemplateError(const std::string &source);
|
||||||
|
|
||||||
|
void destroy();
|
||||||
bool isModelLoaded() const;
|
bool isModelLoaded() const;
|
||||||
void regenerateResponse();
|
void regenerateResponse(int index);
|
||||||
void resetResponse();
|
// used to implement edit functionality
|
||||||
void resetContext();
|
std::optional<QString> popPrompt(int index);
|
||||||
|
|
||||||
void stopGenerating() { m_stopGenerating = true; }
|
void stopGenerating() { m_stopGenerating = true; }
|
||||||
|
|
||||||
@ -165,13 +170,9 @@ public:
|
|||||||
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
|
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
|
||||||
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
|
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
|
||||||
|
|
||||||
QString response(bool trim = true) const;
|
|
||||||
|
|
||||||
ModelInfo modelInfo() const;
|
ModelInfo modelInfo() const;
|
||||||
void setModelInfo(const ModelInfo &info);
|
void setModelInfo(const ModelInfo &info);
|
||||||
|
|
||||||
bool restoringFromText() const { return m_restoringFromText; }
|
|
||||||
|
|
||||||
void acquireModel();
|
void acquireModel();
|
||||||
void resetModel();
|
void resetModel();
|
||||||
|
|
||||||
@ -196,13 +197,11 @@ public:
|
|||||||
return m_llModelInfo.fallbackReason.value_or(u""_s);
|
return m_llModelInfo.fallbackReason.value_or(u""_s);
|
||||||
}
|
}
|
||||||
|
|
||||||
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
|
bool serialize(QDataStream &stream, int version);
|
||||||
|
bool deserialize(QDataStream &stream, int version);
|
||||||
bool serialize(QDataStream &stream, int version, bool serializeKV);
|
|
||||||
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
|
|
||||||
|
|
||||||
public Q_SLOTS:
|
public Q_SLOTS:
|
||||||
bool prompt(const QList<QString> &collectionList, const QString &prompt);
|
void prompt(const QStringList &enabledCollections);
|
||||||
bool loadDefaultModel();
|
bool loadDefaultModel();
|
||||||
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
|
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
|
||||||
bool loadModel(const ModelInfo &modelInfo);
|
bool loadModel(const ModelInfo &modelInfo);
|
||||||
@ -210,22 +209,19 @@ public Q_SLOTS:
|
|||||||
void unloadModel();
|
void unloadModel();
|
||||||
void reloadModel();
|
void reloadModel();
|
||||||
void generateName();
|
void generateName();
|
||||||
void generateQuestions(qint64 elapsed);
|
|
||||||
void handleChatIdChanged(const QString &id);
|
void handleChatIdChanged(const QString &id);
|
||||||
void handleShouldBeLoadedChanged();
|
void handleShouldBeLoadedChanged();
|
||||||
void handleThreadStarted();
|
void handleThreadStarted();
|
||||||
void handleForceMetalChanged(bool forceMetal);
|
void handleForceMetalChanged(bool forceMetal);
|
||||||
void handleDeviceChanged();
|
void handleDeviceChanged();
|
||||||
void processSystemPrompt();
|
|
||||||
void processRestoreStateFromText();
|
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void restoringFromTextChanged();
|
|
||||||
void loadedModelInfoChanged();
|
void loadedModelInfoChanged();
|
||||||
void modelLoadingPercentageChanged(float);
|
void modelLoadingPercentageChanged(float);
|
||||||
void modelLoadingError(const QString &error);
|
void modelLoadingError(const QString &error);
|
||||||
void modelLoadingWarning(const QString &warning);
|
void modelLoadingWarning(const QString &warning);
|
||||||
void responseChanged(const QString &response);
|
void responseChanged(const QString &response);
|
||||||
|
void responseFailed(const QString &error);
|
||||||
void promptProcessing();
|
void promptProcessing();
|
||||||
void generatingQuestions();
|
void generatingQuestions();
|
||||||
void responseStopped(qint64 promptResponseMs);
|
void responseStopped(qint64 promptResponseMs);
|
||||||
@ -244,58 +240,50 @@ Q_SIGNALS:
|
|||||||
void modelInfoChanged(const ModelInfo &modelInfo);
|
void modelInfoChanged(const ModelInfo &modelInfo);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
struct PromptResult {
|
||||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
QByteArray response; // raw UTF-8
|
||||||
int32_t repeat_penalty_tokens, std::optional<QString> fakeReply = {});
|
int promptTokens; // note: counts *entire* history, even if cached
|
||||||
bool handlePrompt(int32_t token);
|
int responseTokens;
|
||||||
bool handleResponse(int32_t token, const std::string &response);
|
};
|
||||||
bool handleNamePrompt(int32_t token);
|
|
||||||
bool handleNameResponse(int32_t token, const std::string &response);
|
|
||||||
bool handleSystemPrompt(int32_t token);
|
|
||||||
bool handleSystemResponse(int32_t token, const std::string &response);
|
|
||||||
bool handleRestoreStateFromTextPrompt(int32_t token);
|
|
||||||
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
|
|
||||||
bool handleQuestionPrompt(int32_t token);
|
|
||||||
bool handleQuestionResponse(int32_t token, const std::string &response);
|
|
||||||
void saveState();
|
|
||||||
void restoreState();
|
|
||||||
|
|
||||||
protected:
|
struct ChatPromptResult : PromptResult {
|
||||||
LLModel::PromptContext m_ctx;
|
QList<ResultInfo> databaseResults;
|
||||||
quint32 m_promptTokens;
|
};
|
||||||
quint32 m_promptResponseTokens;
|
|
||||||
|
ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx);
|
||||||
|
// passing a string_view directly skips templating and uses the raw string
|
||||||
|
PromptResult promptInternal(const std::variant<std::span<const ChatItem>, std::string_view> &prompt,
|
||||||
|
const LLModel::PromptContext &ctx,
|
||||||
|
bool usedLocalDocs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
|
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
|
||||||
|
|
||||||
|
std::vector<ChatItem> forkConversation(const QString &prompt) const;
|
||||||
|
|
||||||
|
// Applies the Jinja template. Query mode returns only the last message without special tokens.
|
||||||
|
// Returns a (# of messages, rendered prompt) pair.
|
||||||
|
std::string applyJinjaTemplate(std::span<const ChatItem> items) const;
|
||||||
|
|
||||||
|
void generateQuestions(qint64 elapsed);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
QPointer<ChatModel> m_chatModel;
|
||||||
|
|
||||||
|
private:
|
||||||
const Chat *m_chat;
|
const Chat *m_chat;
|
||||||
std::string m_response;
|
|
||||||
std::string m_trimmedResponse;
|
|
||||||
std::string m_nameResponse;
|
|
||||||
QString m_questionResponse;
|
|
||||||
LLModelInfo m_llModelInfo;
|
LLModelInfo m_llModelInfo;
|
||||||
LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
|
LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
|
||||||
ModelInfo m_modelInfo;
|
ModelInfo m_modelInfo;
|
||||||
TokenTimer *m_timer;
|
TokenTimer *m_timer;
|
||||||
QByteArray m_state;
|
|
||||||
std::vector<LLModel::Token> m_stateInputTokens;
|
|
||||||
int32_t m_stateContextLength = -1;
|
|
||||||
QThread m_llmThread;
|
QThread m_llmThread;
|
||||||
std::atomic<bool> m_stopGenerating;
|
std::atomic<bool> m_stopGenerating;
|
||||||
std::atomic<bool> m_shouldBeLoaded;
|
std::atomic<bool> m_shouldBeLoaded;
|
||||||
std::atomic<bool> m_restoringFromText; // status indication
|
|
||||||
std::atomic<bool> m_forceUnloadModel;
|
std::atomic<bool> m_forceUnloadModel;
|
||||||
std::atomic<bool> m_markedForDeletion;
|
std::atomic<bool> m_markedForDeletion;
|
||||||
bool m_isServer;
|
bool m_isServer;
|
||||||
bool m_forceMetal;
|
bool m_forceMetal;
|
||||||
bool m_reloadingToChangeVariant;
|
bool m_reloadingToChangeVariant;
|
||||||
bool m_processedSystemPrompt;
|
|
||||||
bool m_restoreStateFromText;
|
|
||||||
// m_pristineLoadedState is set if saveSate is unnecessary, either because:
|
|
||||||
// - an unload was queued during LLModel::restoreState()
|
|
||||||
// - the chat will be restored from text and hasn't been interacted with yet
|
|
||||||
bool m_pristineLoadedState = false;
|
|
||||||
QPointer<ChatModel> m_chatModel;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // CHATLLM_H
|
#endif // CHATLLM_H
|
||||||
|
@ -2,8 +2,11 @@
|
|||||||
#define CHATMODEL_H
|
#define CHATMODEL_H
|
||||||
|
|
||||||
#include "database.h"
|
#include "database.h"
|
||||||
|
#include "utils.h"
|
||||||
#include "xlsxtomd.h"
|
#include "xlsxtomd.h"
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
#include <QAbstractListModel>
|
#include <QAbstractListModel>
|
||||||
#include <QBuffer>
|
#include <QBuffer>
|
||||||
#include <QByteArray>
|
#include <QByteArray>
|
||||||
@ -18,6 +21,15 @@
|
|||||||
#include <Qt>
|
#include <Qt>
|
||||||
#include <QtGlobal>
|
#include <QtGlobal>
|
||||||
|
|
||||||
|
#include <iterator>
|
||||||
|
#include <ranges>
|
||||||
|
#include <span>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using namespace Qt::Literals::StringLiterals;
|
||||||
|
namespace ranges = std::ranges;
|
||||||
|
|
||||||
|
|
||||||
struct PromptAttachment {
|
struct PromptAttachment {
|
||||||
Q_GADGET
|
Q_GADGET
|
||||||
Q_PROPERTY(QUrl url MEMBER url)
|
Q_PROPERTY(QUrl url MEMBER url)
|
||||||
@ -60,66 +72,145 @@ Q_DECLARE_METATYPE(PromptAttachment)
|
|||||||
struct ChatItem
|
struct ChatItem
|
||||||
{
|
{
|
||||||
Q_GADGET
|
Q_GADGET
|
||||||
Q_PROPERTY(QString name MEMBER name)
|
Q_PROPERTY(QString name MEMBER name )
|
||||||
Q_PROPERTY(QString value MEMBER value)
|
Q_PROPERTY(QString value MEMBER value)
|
||||||
Q_PROPERTY(QString newResponse MEMBER newResponse)
|
|
||||||
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
|
// prompts
|
||||||
Q_PROPERTY(bool stopped MEMBER stopped)
|
|
||||||
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
|
|
||||||
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
|
|
||||||
Q_PROPERTY(QList<ResultInfo> sources MEMBER sources)
|
|
||||||
Q_PROPERTY(QList<ResultInfo> consolidatedSources MEMBER consolidatedSources)
|
|
||||||
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments)
|
Q_PROPERTY(QList<PromptAttachment> promptAttachments MEMBER promptAttachments)
|
||||||
Q_PROPERTY(QString promptPlusAttachments READ promptPlusAttachments)
|
Q_PROPERTY(QString bakedPrompt READ bakedPrompt )
|
||||||
|
|
||||||
|
// responses
|
||||||
|
Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse)
|
||||||
|
Q_PROPERTY(bool isError MEMBER isError )
|
||||||
|
|
||||||
|
// responses (DataLake)
|
||||||
|
Q_PROPERTY(QString newResponse MEMBER newResponse )
|
||||||
|
Q_PROPERTY(bool stopped MEMBER stopped )
|
||||||
|
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState )
|
||||||
|
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
|
||||||
|
|
||||||
public:
|
public:
|
||||||
QString promptPlusAttachments() const
|
enum class Type { System, Prompt, Response };
|
||||||
{
|
|
||||||
QStringList attachedContexts;
|
|
||||||
for (auto attached : promptAttachments)
|
|
||||||
attachedContexts << attached.processedContent();
|
|
||||||
|
|
||||||
QString promptPlus = value;
|
// tags for constructing ChatItems
|
||||||
if (!attachedContexts.isEmpty())
|
struct prompt_tag_t { explicit prompt_tag_t() = default; };
|
||||||
promptPlus = attachedContexts.join("\n\n") + "\n\n" + value;
|
static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t();
|
||||||
return promptPlus;
|
struct response_tag_t { explicit response_tag_t() = default; };
|
||||||
|
static inline constexpr response_tag_t response_tag = response_tag_t();
|
||||||
|
struct system_tag_t { explicit system_tag_t() = default; };
|
||||||
|
static inline constexpr system_tag_t system_tag = system_tag_t();
|
||||||
|
|
||||||
|
// FIXME(jared): This should not be necessary. QML should see null or undefined if it
|
||||||
|
// tries to access something invalid.
|
||||||
|
ChatItem() = default;
|
||||||
|
|
||||||
|
// NOTE: system messages are currently never stored in the model or serialized
|
||||||
|
ChatItem(system_tag_t, const QString &value)
|
||||||
|
: name(u"System: "_s), value(value) {}
|
||||||
|
|
||||||
|
ChatItem(prompt_tag_t, const QString &value, const QList<PromptAttachment> &attachments = {})
|
||||||
|
: name(u"Prompt: "_s), value(value), promptAttachments(attachments) {}
|
||||||
|
|
||||||
|
ChatItem(response_tag_t, bool isCurrentResponse = true)
|
||||||
|
: name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {}
|
||||||
|
|
||||||
|
Type type() const
|
||||||
|
{
|
||||||
|
if (name == u"System: "_s)
|
||||||
|
return Type::System;
|
||||||
|
if (name == u"Prompt: "_s)
|
||||||
|
return Type::Prompt;
|
||||||
|
if (name == u"Response: "_s)
|
||||||
|
return Type::Response;
|
||||||
|
throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name));
|
||||||
|
}
|
||||||
|
|
||||||
|
// used with version 0 Jinja templates
|
||||||
|
QString bakedPrompt() const
|
||||||
|
{
|
||||||
|
if (type() != Type::Prompt)
|
||||||
|
throw std::logic_error("bakedPrompt() called on non-prompt item");
|
||||||
|
QStringList parts;
|
||||||
|
if (!sources.isEmpty()) {
|
||||||
|
parts << u"### Context:\n"_s;
|
||||||
|
for (auto &source : std::as_const(sources))
|
||||||
|
parts << u"Collection: "_s << source.collection
|
||||||
|
<< u"\nPath: "_s << source.path
|
||||||
|
<< u"\nExcerpt: "_s << source.text << u"\n\n"_s;
|
||||||
|
}
|
||||||
|
for (auto &attached : std::as_const(promptAttachments))
|
||||||
|
parts << attached.processedContent() << u"\n\n"_s;
|
||||||
|
parts << value;
|
||||||
|
return parts.join(QString());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Maybe we should include the model name here as well as timestamp?
|
// TODO: Maybe we should include the model name here as well as timestamp?
|
||||||
QString name;
|
QString name;
|
||||||
QString value;
|
QString value;
|
||||||
QString newResponse;
|
|
||||||
QList<ResultInfo> sources;
|
// prompts
|
||||||
QList<ResultInfo> consolidatedSources;
|
QList<ResultInfo> sources;
|
||||||
|
QList<ResultInfo> consolidatedSources;
|
||||||
QList<PromptAttachment> promptAttachments;
|
QList<PromptAttachment> promptAttachments;
|
||||||
bool currentResponse = false;
|
|
||||||
bool stopped = false;
|
// responses
|
||||||
bool thumbsUpState = false;
|
bool isCurrentResponse = false;
|
||||||
bool thumbsDownState = false;
|
bool isError = false;
|
||||||
|
|
||||||
|
// responses (DataLake)
|
||||||
|
QString newResponse;
|
||||||
|
bool stopped = false;
|
||||||
|
bool thumbsUpState = false;
|
||||||
|
bool thumbsDownState = false;
|
||||||
};
|
};
|
||||||
Q_DECLARE_METATYPE(ChatItem)
|
Q_DECLARE_METATYPE(ChatItem)
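A usage sketch (illustrative, not from the diff): the tag arguments select the item kind, replacing hand-assignment of the legacy "System: "/"Prompt: "/"Response: " name strings.

    ChatItem systemMsg(ChatItem::system_tag,   u"You are a helpful assistant."_s);
    ChatItem userMsg  (ChatItem::prompt_tag,   u"Summarize this document."_s);
    ChatItem reply    (ChatItem::response_tag, /*isCurrentResponse*/ true);
    Q_ASSERT(userMsg.type() == ChatItem::Type::Prompt);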
|
||||||
|
|
||||||
using ChatModelIterator = QList<ChatItem>::const_iterator;
|
class ChatModelAccessor : public ranges::subrange<QList<ChatItem>::const_iterator> {
|
||||||
|
private:
|
||||||
|
using Super = ranges::subrange<QList<ChatItem>::const_iterator>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
template <typename... T>
|
||||||
|
ChatModelAccessor(QMutex &mutex, T &&...args)
|
||||||
|
: Super(std::forward<T>(args)...), m_lock(&mutex) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
QMutexLocker<QMutex> m_lock;
|
||||||
|
};
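Illustrative use (not from the diff): ChatModel::chatItems() further down returns this accessor, so a caller can iterate the history while the QMutexLocker member keeps m_mutex held for the accessor's lifetime.

    void dumpPrompts(const ChatModel &model)
    {
        for (const ChatItem &item : model.chatItems()) {   // lock held until the accessor dies
            if (item.type() == ChatItem::Type::Prompt)
                qDebug() << "prompt:" << item.value;
        }
    }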
|
||||||
|
|
||||||
class ChatModel : public QAbstractListModel
|
class ChatModel : public QAbstractListModel
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
Q_PROPERTY(int count READ count NOTIFY countChanged)
|
Q_PROPERTY(int count READ count NOTIFY countChanged)
|
||||||
|
Q_PROPERTY(bool hasError READ hasError NOTIFY hasErrorChanged)
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
|
explicit ChatModel(QObject *parent = nullptr)
|
||||||
|
: QAbstractListModel(parent) {}
|
||||||
|
|
||||||
|
// FIXME(jared): can't this start at Qt::UserRole (no +1)?
|
||||||
enum Roles {
|
enum Roles {
|
||||||
NameRole = Qt::UserRole + 1,
|
NameRole = Qt::UserRole + 1,
|
||||||
ValueRole,
|
ValueRole,
|
||||||
|
|
||||||
|
// prompts and responses
|
||||||
|
PeerRole,
|
||||||
|
|
||||||
|
// prompts
|
||||||
|
PromptAttachmentsRole,
|
||||||
|
|
||||||
|
// responses
|
||||||
|
// NOTE: sources are stored on the *prompts*, but in the model, they are only on the *responses*!
|
||||||
|
SourcesRole,
|
||||||
|
ConsolidatedSourcesRole,
|
||||||
|
IsCurrentResponseRole,
|
||||||
|
IsErrorRole,
|
||||||
|
|
||||||
|
// responses (DataLake)
|
||||||
NewResponseRole,
|
NewResponseRole,
|
||||||
CurrentResponseRole,
|
|
||||||
StoppedRole,
|
StoppedRole,
|
||||||
ThumbsUpStateRole,
|
ThumbsUpStateRole,
|
||||||
ThumbsDownStateRole,
|
ThumbsDownStateRole,
|
||||||
SourcesRole,
|
|
||||||
ConsolidatedSourcesRole,
|
|
||||||
PromptAttachmentsRole
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int rowCount(const QModelIndex &parent = QModelIndex()) const override
|
int rowCount(const QModelIndex &parent = QModelIndex()) const override
|
||||||
@ -129,34 +220,96 @@ public:
|
|||||||
return m_chatItems.size();
|
return m_chatItems.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs
|
||||||
|
* sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */
|
||||||
|
auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const
|
||||||
|
-> std::optional<QList<ChatItem>::const_iterator>
|
||||||
|
{
|
||||||
|
switch (item->type()) {
|
||||||
|
using enum ChatItem::Type;
|
||||||
|
case Prompt:
|
||||||
|
{
|
||||||
|
auto peer = std::next(item);
|
||||||
|
if (peer < m_chatItems.cend() && peer->type() == Response)
|
||||||
|
return peer;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Response:
|
||||||
|
{
|
||||||
|
if (item > m_chatItems.cbegin()) {
|
||||||
|
if (auto peer = std::prev(item); peer->type() == Prompt)
|
||||||
|
return peer;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getPeerUnlocked(int index) const -> std::optional<int>
|
||||||
|
{
|
||||||
|
return getPeerUnlocked(m_chatItems.cbegin() + index)
|
||||||
|
.transform([&](auto &&i) { return i - m_chatItems.cbegin(); } );
|
||||||
|
}
|
||||||
|
|
||||||
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
|
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
|
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
|
||||||
return QVariant();
|
return QVariant();
|
||||||
|
|
||||||
const ChatItem &item = m_chatItems.at(index.row());
|
auto item = m_chatItems.cbegin() + index.row();
|
||||||
switch (role) {
|
switch (role) {
|
||||||
case NameRole:
|
case NameRole:
|
||||||
return item.name;
|
return item->name;
|
||||||
case ValueRole:
|
case ValueRole:
|
||||||
return item.value;
|
return item->value;
|
||||||
case NewResponseRole:
|
case PeerRole:
|
||||||
return item.newResponse;
|
switch (item->type()) {
|
||||||
case CurrentResponseRole:
|
using enum ChatItem::Type;
|
||||||
return item.currentResponse;
|
case Prompt:
|
||||||
case StoppedRole:
|
case Response:
|
||||||
return item.stopped;
|
{
|
||||||
case ThumbsUpStateRole:
|
auto peer = getPeerUnlocked(item);
|
||||||
return item.thumbsUpState;
|
return peer ? QVariant::fromValue(**peer) : QVariant::fromValue(nullptr);
|
||||||
case ThumbsDownStateRole:
|
}
|
||||||
return item.thumbsDownState;
|
default:
|
||||||
case SourcesRole:
|
return QVariant();
|
||||||
return QVariant::fromValue(item.sources);
|
}
|
||||||
case ConsolidatedSourcesRole:
|
|
||||||
return QVariant::fromValue(item.consolidatedSources);
|
|
||||||
case PromptAttachmentsRole:
|
case PromptAttachmentsRole:
|
||||||
return QVariant::fromValue(item.promptAttachments);
|
return QVariant::fromValue(item->promptAttachments);
|
||||||
|
case SourcesRole:
|
||||||
|
{
|
||||||
|
QList<ResultInfo> data;
|
||||||
|
if (item->type() == ChatItem::Type::Response) {
|
||||||
|
if (auto prompt = getPeerUnlocked(item))
|
||||||
|
data = (*prompt)->consolidatedSources;
|
||||||
|
}
|
||||||
|
return QVariant::fromValue(data);
|
||||||
|
}
|
||||||
|
case ConsolidatedSourcesRole:
|
||||||
|
{
|
||||||
|
QList<ResultInfo> data;
|
||||||
|
if (item->type() == ChatItem::Type::Response) {
|
||||||
|
if (auto prompt = getPeerUnlocked(item))
|
||||||
|
data = (*prompt)->sources;
|
||||||
|
}
|
||||||
|
return QVariant::fromValue(data);
|
||||||
|
}
|
||||||
|
case IsCurrentResponseRole:
|
||||||
|
return item->isCurrentResponse;
|
||||||
|
case NewResponseRole:
|
||||||
|
return item->newResponse;
|
||||||
|
case StoppedRole:
|
||||||
|
return item->stopped;
|
||||||
|
case ThumbsUpStateRole:
|
||||||
|
return item->thumbsUpState;
|
||||||
|
case ThumbsDownStateRole:
|
||||||
|
return item->thumbsDownState;
|
||||||
|
case IsErrorRole:
|
||||||
|
return item->type() == ChatItem::Type::Response && item->isError;
|
||||||
}
|
}
|
||||||
|
|
||||||
return QVariant();
|
return QVariant();
|
||||||
@ -164,54 +317,126 @@ public:
|
|||||||
|
|
||||||
QHash<int, QByteArray> roleNames() const override
|
QHash<int, QByteArray> roleNames() const override
|
||||||
{
|
{
|
||||||
QHash<int, QByteArray> roles;
|
return {
|
||||||
roles[NameRole] = "name";
|
{ NameRole, "name" },
|
||||||
roles[ValueRole] = "value";
|
{ ValueRole, "value" },
|
||||||
roles[NewResponseRole] = "newResponse";
|
{ PeerRole, "peer" },
|
||||||
roles[CurrentResponseRole] = "currentResponse";
|
{ PromptAttachmentsRole, "promptAttachments" },
|
||||||
roles[StoppedRole] = "stopped";
|
{ SourcesRole, "sources" },
|
||||||
roles[ThumbsUpStateRole] = "thumbsUpState";
|
{ ConsolidatedSourcesRole, "consolidatedSources" },
|
||||||
roles[ThumbsDownStateRole] = "thumbsDownState";
|
{ IsCurrentResponseRole, "isCurrentResponse" },
|
||||||
roles[SourcesRole] = "sources";
|
{ IsErrorRole, "isError" },
|
||||||
roles[ConsolidatedSourcesRole] = "consolidatedSources";
|
{ NewResponseRole, "newResponse" },
|
||||||
roles[PromptAttachmentsRole] = "promptAttachments";
|
{ StoppedRole, "stopped" },
|
||||||
return roles;
|
{ ThumbsUpStateRole, "thumbsUpState" },
|
||||||
|
{ ThumbsDownStateRole, "thumbsDownState" },
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
void appendPrompt(const QString &name, const QString &value, const QList<PromptAttachment> &attachments)
|
void appendPrompt(const QString &value, const QList<PromptAttachment> &attachments = {})
|
||||||
{
|
{
|
||||||
ChatItem item;
|
qsizetype count;
|
||||||
item.name = name;
|
{
|
||||||
item.value = value;
|
QMutexLocker locker(&m_mutex);
|
||||||
item.promptAttachments << attachments;
|
if (hasErrorUnlocked())
|
||||||
|
throw std::logic_error("cannot append to a failed chat");
|
||||||
|
count = m_chatItems.count();
|
||||||
|
}
|
||||||
|
|
||||||
m_mutex.lock();
|
|
||||||
const int count = m_chatItems.count();
|
|
||||||
m_mutex.unlock();
|
|
||||||
beginInsertRows(QModelIndex(), count, count);
|
beginInsertRows(QModelIndex(), count, count);
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
m_chatItems.append(item);
|
m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments);
|
||||||
}
|
}
|
||||||
endInsertRows();
|
endInsertRows();
|
||||||
emit countChanged();
|
emit countChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
void appendResponse(const QString &name)
|
void appendResponse(int promptIndex)
|
||||||
{
|
{
|
||||||
m_mutex.lock();
|
qsizetype count;
|
||||||
const int count = m_chatItems.count();
|
{
|
||||||
m_mutex.unlock();
|
QMutexLocker locker(&m_mutex);
|
||||||
ChatItem item;
|
if (hasErrorUnlocked())
|
||||||
item.name = name;
|
throw std::logic_error("cannot append to a failed chat");
|
||||||
item.currentResponse = true;
|
count = m_chatItems.count();
|
||||||
|
}
|
||||||
|
|
||||||
beginInsertRows(QModelIndex(), count, count);
|
beginInsertRows(QModelIndex(), count, count);
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
m_chatItems.append(item);
|
if (promptIndex >= 0) {
|
||||||
|
if (promptIndex >= m_chatItems.size())
|
||||||
|
throw std::out_of_range(fmt::format("index {} is out of range", promptIndex));
|
||||||
|
auto &promptItem = m_chatItems[promptIndex];
|
||||||
|
if (promptItem.type() != ChatItem::Type::Prompt)
|
||||||
|
throw std::invalid_argument(fmt::format("item at index {} is not a prompt", promptIndex));
|
||||||
|
}
|
||||||
|
m_chatItems.emplace_back(ChatItem::response_tag, promptIndex);
|
||||||
}
|
}
|
||||||
endInsertRows();
|
endInsertRows();
|
||||||
emit countChanged();
|
emit countChanged();
|
||||||
|
if (promptIndex >= 0)
|
||||||
|
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used by Server to append a new conversation to the chat log.
|
||||||
|
void appendResponseWithHistory(std::span<const ChatItem> history)
|
||||||
|
{
|
||||||
|
if (history.empty())
|
||||||
|
throw std::invalid_argument("at least one message is required");
|
||||||
|
|
||||||
|
m_mutex.lock();
|
||||||
|
qsizetype startIndex = m_chatItems.count();
|
||||||
|
m_mutex.unlock();
|
||||||
|
|
||||||
|
qsizetype nNewItems = history.size() + 1;
|
||||||
|
qsizetype endIndex = startIndex + nNewItems;
|
||||||
|
beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
|
||||||
|
bool hadError;
|
||||||
|
int promptIndex;
|
||||||
|
{
|
||||||
|
QMutexLocker locker(&m_mutex);
|
||||||
|
hadError = hasErrorUnlocked();
|
||||||
|
m_chatItems.reserve(m_chatItems.count() + nNewItems);
|
||||||
|
for (auto &item : history)
|
||||||
|
m_chatItems << item;
|
||||||
|
m_chatItems.emplace_back(ChatItem::response_tag);
|
||||||
|
}
|
||||||
|
endInsertRows();
|
||||||
|
emit countChanged();
|
||||||
|
// Server can add messages when there is an error because each call is a new conversation
|
||||||
|
if (hadError)
|
||||||
|
emit hasErrorChanged(false);
|
||||||
|
if (promptIndex >= 0)
|
||||||
|
emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole});
|
||||||
|
}
|
||||||
|
|
||||||
|
void truncate(qsizetype size)
|
||||||
|
{
|
||||||
|
qsizetype oldSize;
|
||||||
|
{
|
||||||
|
QMutexLocker locker(&m_mutex);
|
||||||
|
if (size >= (oldSize = m_chatItems.size()))
|
||||||
|
return;
|
||||||
|
if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
fmt::format("chat model truncated to {} items would not end in a response", size)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool oldHasError;
|
||||||
|
beginRemoveRows(QModelIndex(), size, oldSize - 1 /*inclusive*/);
|
||||||
|
{
|
||||||
|
QMutexLocker locker(&m_mutex);
|
||||||
|
oldHasError = hasErrorUnlocked();
|
||||||
|
Q_ASSERT(size < m_chatItems.size());
|
||||||
|
m_chatItems.resize(size);
|
||||||
|
}
|
||||||
|
endRemoveRows();
|
||||||
|
emit countChanged();
|
||||||
|
if (oldHasError)
|
||||||
|
emit hasErrorChanged(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
Q_INVOKABLE void clear()
|
Q_INVOKABLE void clear()
|
||||||
@ -221,13 +446,17 @@ public:
|
|||||||
if (m_chatItems.isEmpty()) return;
|
if (m_chatItems.isEmpty()) return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool oldHasError;
|
||||||
beginResetModel();
|
beginResetModel();
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
|
oldHasError = hasErrorUnlocked();
|
||||||
m_chatItems.clear();
|
m_chatItems.clear();
|
||||||
}
|
}
|
||||||
endResetModel();
|
endResetModel();
|
||||||
emit countChanged();
|
emit countChanged();
|
||||||
|
if (oldHasError)
|
||||||
|
emit hasErrorChanged(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
Q_INVOKABLE ChatItem get(int index)
|
Q_INVOKABLE ChatItem get(int index)
|
||||||
@ -245,13 +474,13 @@ public:
|
|||||||
if (index < 0 || index >= m_chatItems.size()) return;
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
ChatItem &item = m_chatItems[index];
|
ChatItem &item = m_chatItems[index];
|
||||||
if (item.currentResponse != b) {
|
if (item.isCurrentResponse != b) {
|
||||||
item.currentResponse = b;
|
item.isCurrentResponse = b;
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
|
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole});
|
||||||
}
|
}
|
||||||
|
|
||||||
Q_INVOKABLE void updateStopped(int index, bool b)
|
Q_INVOKABLE void updateStopped(int index, bool b)
|
||||||
@ -304,16 +533,23 @@ public:
|
|||||||
|
|
||||||
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
|
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
|
||||||
{
|
{
|
||||||
|
int responseIndex = -1;
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
if (index < 0 || index >= m_chatItems.size()) return;
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
ChatItem &item = m_chatItems[index];
|
auto promptItem = m_chatItems.begin() + index;
|
||||||
item.sources = sources;
|
if (promptItem->type() != ChatItem::Type::Prompt)
|
||||||
item.consolidatedSources = consolidateSources(sources);
|
throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index));
|
||||||
|
if (auto peer = getPeerUnlocked(promptItem))
|
||||||
|
responseIndex = *peer - m_chatItems.cbegin();
|
||||||
|
promptItem->sources = sources;
|
||||||
|
promptItem->consolidatedSources = consolidateSources(sources);
|
||||||
|
}
|
||||||
|
if (responseIndex >= 0) {
|
||||||
|
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole});
|
||||||
|
emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {ConsolidatedSourcesRole});
|
||||||
}
|
}
|
||||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
|
|
||||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
|
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
|
||||||
@ -364,18 +600,56 @@ public:
|
|||||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
||||||
}
|
}
|
||||||
|
|
||||||
int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
|
Q_INVOKABLE void setError(bool value = true)
|
||||||
|
{
|
||||||
|
qsizetype index;
|
||||||
|
{
|
||||||
|
QMutexLocker locker(&m_mutex);
|
||||||
|
|
||||||
ChatModelIterator begin() const { return m_chatItems.begin(); }
|
if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response)
|
||||||
ChatModelIterator end() const { return m_chatItems.end(); }
|
throw std::logic_error("can only set error on a chat that ends with a response");
|
||||||
void lock() { m_mutex.lock(); }
|
|
||||||
void unlock() { m_mutex.unlock(); }
|
index = m_chatItems.count() - 1;
|
||||||
|
auto &last = m_chatItems.back();
|
||||||
|
if (last.isError == value)
|
||||||
|
return; // already set
|
||||||
|
last.isError = value;
|
||||||
|
}
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole});
|
||||||
|
emit hasErrorChanged(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
|
||||||
|
|
||||||
|
ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; }
|
||||||
|
|
||||||
|
bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); }
|
||||||
|
|
||||||
bool serialize(QDataStream &stream, int version) const
|
bool serialize(QDataStream &stream, int version) const
|
||||||
{
|
{
|
||||||
QMutexLocker locker(&m_mutex);
|
QMutexLocker locker(&m_mutex);
|
||||||
stream << int(m_chatItems.size());
|
stream << int(m_chatItems.size());
|
||||||
for (const auto &c : m_chatItems) {
|
for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) {
|
||||||
|
auto c = *itemIt; // NB: copies
|
||||||
|
if (version < 11) {
|
||||||
|
// move sources from their prompt to the next response
|
||||||
|
switch (c.type()) {
|
||||||
|
using enum ChatItem::Type;
|
||||||
|
case Prompt:
|
||||||
|
c.sources.clear();
|
||||||
|
c.consolidatedSources.clear();
|
||||||
|
break;
|
||||||
|
case Response:
|
||||||
|
// note: we drop sources for responseless prompts
|
||||||
|
if (auto peer = getPeerUnlocked(itemIt)) {
|
||||||
|
c.sources = (*peer)->sources;
|
||||||
|
c.consolidatedSources = (*peer)->consolidatedSources;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME: This 'id' should be eliminated the next time we bump serialization version.
|
// FIXME: This 'id' should be eliminated the next time we bump serialization version.
|
||||||
// (Jared) This was apparently never used.
|
// (Jared) This was apparently never used.
|
||||||
int id = 0;
|
int id = 0;
|
||||||
@ -383,10 +657,12 @@ public:
|
|||||||
stream << c.name;
|
stream << c.name;
|
||||||
stream << c.value;
|
stream << c.value;
|
||||||
stream << c.newResponse;
|
stream << c.newResponse;
|
||||||
stream << c.currentResponse;
|
stream << c.isCurrentResponse;
|
||||||
stream << c.stopped;
|
stream << c.stopped;
|
||||||
stream << c.thumbsUpState;
|
stream << c.thumbsUpState;
|
||||||
stream << c.thumbsDownState;
|
stream << c.thumbsDownState;
|
||||||
|
if (version >= 11 && c.type() == ChatItem::Type::Response)
|
||||||
|
stream << c.isError;
|
||||||
if (version >= 8) {
|
if (version >= 8) {
|
||||||
stream << c.sources.size();
|
stream << c.sources.size();
|
||||||
for (const ResultInfo &info : c.sources) {
|
for (const ResultInfo &info : c.sources) {
|
||||||
@ -452,14 +728,24 @@ public:
|
|||||||
|
|
||||||
bool deserialize(QDataStream &stream, int version)
|
bool deserialize(QDataStream &stream, int version)
|
||||||
{
|
{
|
||||||
|
clear(); // reset to known state
|
||||||
|
|
||||||
int size;
|
int size;
|
||||||
stream >> size;
|
stream >> size;
|
||||||
|
int lastPromptIndex = -1;
|
||||||
|
QList<ChatItem> chatItems;
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
ChatItem c;
|
ChatItem c;
|
||||||
// FIXME: see comment in serialization about id
|
// FIXME: see comment in serialization about id
|
||||||
int id;
|
int id;
|
||||||
stream >> id;
|
stream >> id;
|
||||||
stream >> c.name;
|
stream >> c.name;
|
||||||
|
try {
|
||||||
|
c.type(); // check name
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
qWarning() << "ChatModel ERROR:" << e.what();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
stream >> c.value;
|
stream >> c.value;
|
||||||
if (version < 10) {
|
if (version < 10) {
|
||||||
// This is deprecated and no longer used
|
// This is deprecated and no longer used
|
||||||
@ -467,10 +753,12 @@ public:
|
|||||||
stream >> prompt;
|
stream >> prompt;
|
||||||
}
|
}
|
||||||
stream >> c.newResponse;
|
stream >> c.newResponse;
|
||||||
stream >> c.currentResponse;
|
stream >> c.isCurrentResponse;
|
||||||
stream >> c.stopped;
|
stream >> c.stopped;
|
||||||
stream >> c.thumbsUpState;
|
stream >> c.thumbsUpState;
|
||||||
stream >> c.thumbsDownState;
|
stream >> c.thumbsDownState;
|
||||||
|
if (version >= 11 && c.type() == ChatItem::Type::Response)
|
||||||
|
stream >> c.isError;
|
||||||
if (version >= 8) {
|
if (version >= 8) {
|
||||||
qsizetype count;
|
qsizetype count;
|
||||||
stream >> count;
|
stream >> count;
|
||||||
@ -587,23 +875,53 @@ public:
|
|||||||
}
|
}
|
||||||
c.promptAttachments = attachments;
|
c.promptAttachments = attachments;
|
||||||
}
|
}
|
||||||
m_mutex.lock();
|
|
||||||
const int count = m_chatItems.size();
|
if (version < 11 && c.type() == ChatItem::Type::Response) {
|
||||||
m_mutex.unlock();
|
// move sources from the response to their last prompt
|
||||||
beginInsertRows(QModelIndex(), count, count);
|
if (lastPromptIndex >= 0) {
|
||||||
{
|
auto &prompt = chatItems[lastPromptIndex];
|
||||||
QMutexLocker locker(&m_mutex);
|
prompt.sources = std::move(c.sources );
|
||||||
m_chatItems.append(c);
|
prompt.consolidatedSources = std::move(c.consolidatedSources);
|
||||||
|
lastPromptIndex = -1;
|
||||||
|
} else {
|
||||||
|
// drop sources for promptless responses
|
||||||
|
c.sources.clear();
|
||||||
|
c.consolidatedSources.clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
endInsertRows();
|
|
||||||
|
chatItems << c;
|
||||||
|
if (c.type() == ChatItem::Type::Prompt)
|
||||||
|
lastPromptIndex = chatItems.size() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasError;
|
||||||
|
beginInsertRows(QModelIndex(), 0, chatItems.size() - 1 /*inclusive*/);
|
||||||
|
{
|
||||||
|
QMutexLocker locker(&m_mutex);
|
||||||
|
m_chatItems = chatItems;
|
||||||
|
hasError = hasErrorUnlocked();
|
||||||
|
}
|
||||||
|
endInsertRows();
|
||||||
emit countChanged();
|
emit countChanged();
|
||||||
|
if (hasError)
|
||||||
|
emit hasErrorChanged(true);
|
||||||
return stream.status() == QDataStream::Ok;
|
return stream.status() == QDataStream::Ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void countChanged();
|
void countChanged();
|
||||||
void valueChanged(int index, const QString &value);
|
void valueChanged(int index, const QString &value);
|
||||||
|
void hasErrorChanged(bool value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool hasErrorUnlocked() const
|
||||||
|
{
|
||||||
|
if (m_chatItems.isEmpty())
|
||||||
|
return false;
|
||||||
|
auto &last = m_chatItems.back();
|
||||||
|
return last.type() == ChatItem::Type::Response && last.isError;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutable QMutex m_mutex;
|
mutable QMutex m_mutex;
|
||||||
|
gpt4all-chat/src/jinja_helpers.cpp (new file, 111 lines)
@ -0,0 +1,111 @@
#include "jinja_helpers.h"

#include "utils.h"

#include <fmt/format.h>

#include <QString>
#include <QUrl>

#include <memory>
#include <vector>

using namespace std::literals::string_view_literals;


JinjaResultInfo::~JinjaResultInfo() = default;

const JinjaFieldMap<ResultInfo> JinjaResultInfo::s_fields = {
    { "collection", [](auto &s) { return s.collection.toStdString(); } },
    { "path",       [](auto &s) { return s.path      .toStdString(); } },
    { "file",       [](auto &s) { return s.file      .toStdString(); } },
    { "title",      [](auto &s) { return s.title     .toStdString(); } },
    { "author",     [](auto &s) { return s.author    .toStdString(); } },
    { "date",       [](auto &s) { return s.date      .toStdString(); } },
    { "text",       [](auto &s) { return s.text      .toStdString(); } },
    { "page",       [](auto &s) { return s.page; } },
    { "file_uri",   [](auto &s) { return s.fileUri() .toStdString(); } },
};

JinjaPromptAttachment::~JinjaPromptAttachment() = default;

const JinjaFieldMap<PromptAttachment> JinjaPromptAttachment::s_fields = {
    { "url",               [](auto &s) { return s.url.toString()     .toStdString(); } },
    { "file",              [](auto &s) { return s.file()             .toStdString(); } },
    { "processed_content", [](auto &s) { return s.processedContent().toStdString(); } },
};
|
||||||
|
|
||||||
|
std::vector<std::string> JinjaMessage::GetKeys() const
|
||||||
|
{
|
||||||
|
std::vector<std::string> result;
|
||||||
|
auto &keys = this->keys();
|
||||||
|
result.reserve(keys.size());
|
||||||
|
result.assign(keys.begin(), keys.end());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto JinjaMessage::keys() const -> const std::unordered_set<std::string_view> &
|
||||||
|
{
|
||||||
|
static const std::unordered_set<std::string_view> baseKeys
|
||||||
|
{ "role", "content" };
|
||||||
|
static const std::unordered_set<std::string_view> userKeys
|
||||||
|
{ "role", "content", "sources", "prompt_attachments" };
|
||||||
|
switch (m_item->type()) {
|
||||||
|
using enum ChatItem::Type;
|
||||||
|
case System:
|
||||||
|
case Response:
|
||||||
|
return baseKeys;
|
||||||
|
case Prompt:
|
||||||
|
return userKeys;
|
||||||
|
}
|
||||||
|
Q_UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const JinjaMessage &a, const JinjaMessage &b)
|
||||||
|
{
|
||||||
|
if (a.m_item == b.m_item)
|
||||||
|
return true;
|
||||||
|
const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item);
|
||||||
|
auto type = ia.type();
|
||||||
|
if (type != ib.type() || ia.value != ib.value)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
switch (type) {
|
||||||
|
using enum ChatItem::Type;
|
||||||
|
case System:
|
||||||
|
case Response:
|
||||||
|
return true;
|
||||||
|
case Prompt:
|
||||||
|
return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments;
|
||||||
|
}
|
||||||
|
Q_UNREACHABLE();
|
||||||
|
}
|
||||||
|
|
||||||
|
const JinjaFieldMap<JinjaMessage> JinjaMessage::s_fields = {
|
||||||
|
{ "role", [](auto &m) {
|
||||||
|
switch (m.item().type()) {
|
||||||
|
using enum ChatItem::Type;
|
||||||
|
case System: return "system"sv;
|
||||||
|
case Prompt: return "user"sv;
|
||||||
|
case Response: return "assistant"sv;
|
||||||
|
}
|
||||||
|
Q_UNREACHABLE();
|
||||||
|
} },
|
||||||
|
{ "content", [](auto &m) {
|
||||||
|
if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt)
|
||||||
|
return m.item().bakedPrompt().toStdString();
|
||||||
|
return m.item().value.toStdString();
|
||||||
|
} },
|
||||||
|
{ "sources", [](auto &m) {
|
||||||
|
auto sources = m.item().sources | views::transform([](auto &r) {
|
||||||
|
return jinja2::GenericMap([map = std::make_shared<JinjaResultInfo>(r)] { return map.get(); });
|
||||||
|
});
|
||||||
|
return jinja2::ValuesList(sources.begin(), sources.end());
|
||||||
|
} },
|
||||||
|
{ "prompt_attachments", [](auto &m) {
|
||||||
|
auto attachments = m.item().promptAttachments | views::transform([](auto &pa) {
|
||||||
|
return jinja2::GenericMap([map = std::make_shared<JinjaPromptAttachment>(pa)] { return map.get(); });
|
||||||
|
});
|
||||||
|
return jinja2::ValuesList(attachments.begin(), attachments.end());
|
||||||
|
} },
|
||||||
|
};
|
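The field maps above define which keys a chat template can read from each message: every message exposes role and content, and user prompts additionally expose sources (LocalDocs excerpts) and prompt_attachments. As a rough, hedged sketch of how a template could consume those keys, the fragment below is written as a C++ raw string; it is illustrative only, not a template shipped by GPT4All, and the "messages" variable name is an assumption about how the template is invoked.

// Illustrative only: reads the keys exposed by JinjaMessage::s_fields above
// ("role", "content", and for user prompts "sources" / "prompt_attachments").
static constexpr const char *kExampleChatTemplate = R"TPL(
{%- for message in messages -%}
<|{{ message.role }}|>
{{ message.content }}
{%- if message.role == 'user' -%}
{%-   for source in message.sources %} [{{ source.file }}, p. {{ source.page }}]{% endfor -%}
{%- endif -%}
{%- endfor -%}
)TPL";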
116  gpt4all-chat/src/jinja_helpers.h  Normal file
@@ -0,0 +1,116 @@
#pragma once

#include "chatmodel.h"
#include "database.h"

#include <jinja2cpp/value.h>

#include <functional>
#include <ranges>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>

#include <QtGlobal>

namespace views = std::views;


template <typename T>
using JinjaFieldMap = std::unordered_map<std::string_view, std::function<jinja2::Value (const T &)>>;

template <typename Derived>
class JinjaComparable : public jinja2::IMapItemAccessor {
public:
    JinjaComparable() = default;

    bool IsEqual(const jinja2::IComparable &other) const override;

private:
    Q_DISABLE_COPY_MOVE(JinjaComparable)
};

template <typename Derived>
class JinjaHelper : public JinjaComparable<Derived> {
public:
    size_t GetSize() const override
    { return Derived::s_fields.size(); }

    bool HasValue(const std::string &name) const override
    { return Derived::s_fields.contains(name); }

    jinja2::Value GetValueByName(const std::string &name) const override;

    std::vector<std::string> GetKeys() const override
    { auto keys = views::elements<0>(Derived::s_fields); return { keys.begin(), keys.end() }; }
};

class JinjaResultInfo : public JinjaHelper<JinjaResultInfo> {
public:
    explicit JinjaResultInfo(const ResultInfo &source) noexcept
        : m_source(&source) {}

    ~JinjaResultInfo() override;

    const ResultInfo &value() const { return *m_source; }

    friend bool operator==(const JinjaResultInfo &a, const JinjaResultInfo &b)
    { return a.m_source == b.m_source || *a.m_source == *b.m_source; }

private:
    static const JinjaFieldMap<ResultInfo> s_fields;
    const ResultInfo *m_source;

    friend class JinjaHelper<JinjaResultInfo>;
};

class JinjaPromptAttachment : public JinjaHelper<JinjaPromptAttachment> {
public:
    explicit JinjaPromptAttachment(const PromptAttachment &attachment) noexcept
        : m_attachment(&attachment) {}

    ~JinjaPromptAttachment() override;

    const PromptAttachment &value() const { return *m_attachment; }

    friend bool operator==(const JinjaPromptAttachment &a, const JinjaPromptAttachment &b)
    { return a.m_attachment == b.m_attachment || *a.m_attachment == *b.m_attachment; }

private:
    static const JinjaFieldMap<PromptAttachment> s_fields;
    const PromptAttachment *m_attachment;

    friend class JinjaHelper<JinjaPromptAttachment>;
};

class JinjaMessage : public JinjaHelper<JinjaMessage> {
public:
    explicit JinjaMessage(uint version, const ChatItem &item) noexcept
        : m_version(version), m_item(&item) {}

    const JinjaMessage &value  () const { return *this;      }
    uint                version() const { return m_version;  }
    const ChatItem     &item   () const { return *m_item;    }

    size_t GetSize() const override { return keys().size(); }
    bool HasValue(const std::string &name) const override { return keys().contains(name); }

    jinja2::Value GetValueByName(const std::string &name) const override
    { return HasValue(name) ? JinjaHelper::GetValueByName(name) : jinja2::EmptyValue(); }

    std::vector<std::string> GetKeys() const override;

private:
    auto keys() const -> const std::unordered_set<std::string_view> &;

private:
    static const JinjaFieldMap<JinjaMessage> s_fields;
    uint m_version;
    const ChatItem *m_item;

    friend class JinjaHelper<JinjaMessage>;
    friend bool operator==(const JinjaMessage &a, const JinjaMessage &b);
};

#include "jinja_helpers.inl"
17  gpt4all-chat/src/jinja_helpers.inl  Normal file
@@ -0,0 +1,17 @@
template <typename D>
bool JinjaComparable<D>::IsEqual(const jinja2::IComparable &other) const
{
    if (auto *omsg = dynamic_cast<const D *>(&other))
        return *static_cast<const D *>(this) == *omsg;
    return false;
}

template <typename D>
jinja2::Value JinjaHelper<D>::GetValueByName(const std::string &name) const
{
    if (auto it = D::s_fields.find(name); it != D::s_fields.end()) {
        auto [_, func] = *it;
        return func(static_cast<const D *>(this)->value());
    }
    return jinja2::EmptyValue();
}
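Taken together, these helpers resolve a template's field lookups through the static s_fields tables, so unknown keys come back as jinja2::EmptyValue() rather than throwing. A minimal usage sketch follows; the ResultInfo member assignments are assumptions about its layout, inferred from the field map in jinja_helpers.cpp, and the snippet assumes jinja_helpers.h and database.h are in scope.

// Sketch only; names marked as assumed are not guaranteed by this commit.
static void resultInfoLookupExample()
{
    ResultInfo info;
    info.file = "report.pdf";  // assumed QString member, matching the "file" entry above
    info.page = 3;             // assumed int member, matching the "page" entry above
    JinjaResultInfo accessor(info);
    jinja2::Value file  = accessor.GetValueByName("file");   // holds std::string "report.pdf"
    bool known   = accessor.HasValue("title");               // true: key exists in s_fields
    bool unknown = accessor.HasValue("not_a_field");         // false: a template sees it as undefined
    Q_UNUSED(file) Q_UNUSED(known) Q_UNUSED(unknown)
}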
@@ -12,12 +12,16 @@
#include <singleapplication.h>

#include <QCoreApplication>
+#include <QFont>
+#include <QFontDatabase>
#include <QObject>
#include <QQmlApplicationEngine>
+#include <QQmlContext>
#include <QQuickWindow>
#include <QSettings>
#include <QString>
#include <QUrl>
+#include <QVariant>
#include <Qt>

#ifdef Q_OS_LINUX
@@ -91,18 +95,22 @@ int main(int argc, char *argv[])

    // Set the local and language translation before the qml engine has even been started. This will
    // use the default system locale unless the user has explicitly set it to use a different one.
-   MySettings::globalInstance()->setLanguageAndLocale();
+   auto *mySettings = MySettings::globalInstance();
+   mySettings->setLanguageAndLocale();

    QQmlApplicationEngine engine;

    // Add a connection here from MySettings::languageAndLocaleChanged signal to a lambda slot where I can call
    // engine.uiLanguage property
-   QObject::connect(MySettings::globalInstance(), &MySettings::languageAndLocaleChanged, [&engine]() {
+   QObject::connect(mySettings, &MySettings::languageAndLocaleChanged, [&engine]() {
        engine.setUiLanguage(MySettings::globalInstance()->languageAndLocale());
    });

-   qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", MySettings::globalInstance());
-   qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", ModelList::globalInstance());
+   auto *modelList = ModelList::globalInstance();
+   QObject::connect(modelList, &ModelList::dataChanged, mySettings, &MySettings::onModelInfoChanged);
+
+   qmlRegisterSingletonInstance("mysettings", 1, 0, "MySettings", mySettings);
+   qmlRegisterSingletonInstance("modellist", 1, 0, "ModelList", modelList);
    qmlRegisterSingletonInstance("chatlistmodel", 1, 0, "ChatListModel", ChatListModel::globalInstance());
    qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance());
    qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance());
@@ -110,6 +118,11 @@ int main(int argc, char *argv[])
    qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance());
    qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums");

+   {
+       auto fixedFont = QFontDatabase::systemFont(QFontDatabase::FixedFont);
+       engine.rootContext()->setContextProperty("fixedFont", fixedFont);
+   }
+
    const QUrl url(u"qrc:/gpt4all/main.qml"_s);

    QObject::connect(&engine, &QQmlApplicationEngine::objectCreated,
@@ -316,26 +316,44 @@ void ModelInfo::setRepeatPenaltyTokens(int t)
    m_repeatPenaltyTokens = t;
}

-QString ModelInfo::promptTemplate() const
+QVariant ModelInfo::defaultChatTemplate() const
{
-   return MySettings::globalInstance()->modelPromptTemplate(*this);
+   auto res = m_chatTemplate.or_else([this] -> std::optional<QString> {
+       if (!installed || isOnline)
+           return std::nullopt;
+       if (!m_modelChatTemplate) {
+           auto path = (dirpath + filename()).toUtf8();
+           auto res = LLModel::Implementation::chatTemplate(path.constData());
+           if (res) {
+               m_modelChatTemplate = QString::fromStdString(*res);
+           } else {
+               qWarning().nospace() << "failed to get chat template for " << filename() << ": " << res.error().c_str();
+               m_modelChatTemplate = QString(); // do not retry
+           }
+       }
+       if (m_modelChatTemplate->isNull())
+           return std::nullopt;
+       return m_modelChatTemplate;
+   });
+
+   if (res)
+       return std::move(*res);
+   return QVariant::fromValue(nullptr);
}

-void ModelInfo::setPromptTemplate(const QString &t)
+auto ModelInfo::chatTemplate() const -> UpgradeableSetting
{
-   if (shouldSaveMetadata()) MySettings::globalInstance()->setModelPromptTemplate(*this, t, true /*force*/);
-   m_promptTemplate = t;
+   return MySettings::globalInstance()->modelChatTemplate(*this);
}

-QString ModelInfo::systemPrompt() const
+QString ModelInfo::defaultSystemMessage() const
{
-   return MySettings::globalInstance()->modelSystemPrompt(*this);
+   return m_systemMessage;
}

-void ModelInfo::setSystemPrompt(const QString &p)
+auto ModelInfo::systemMessage() const -> UpgradeableSetting
{
-   if (shouldSaveMetadata()) MySettings::globalInstance()->setModelSystemPrompt(*this, p, true /*force*/);
-   m_systemPrompt = p;
+   return MySettings::globalInstance()->modelSystemMessage(*this);
}

QString ModelInfo::chatNamePrompt() const
@@ -360,39 +378,41 @@ void ModelInfo::setSuggestedFollowUpPrompt(const QString &p)
    m_suggestedFollowUpPrompt = p;
}

+// FIXME(jared): this should not be used for model settings that have meaningful defaults, such as temperature
bool ModelInfo::shouldSaveMetadata() const
{
    return installed && (isClone() || isDiscovered() || description() == "" /*indicates sideloaded*/);
}

-QVariantMap ModelInfo::getFields() const
+QVariant ModelInfo::getField(QLatin1StringView name) const
{
-   return {
-       { "filename", m_filename },
-       { "description", m_description },
-       { "url", m_url },
-       { "quant", m_quant },
-       { "type", m_type },
-       { "isClone", m_isClone },
-       { "isDiscovered", m_isDiscovered },
-       { "likes", m_likes },
-       { "downloads", m_downloads },
-       { "recency", m_recency },
-       { "temperature", m_temperature },
-       { "topP", m_topP },
-       { "minP", m_minP },
-       { "topK", m_topK },
-       { "maxLength", m_maxLength },
-       { "promptBatchSize", m_promptBatchSize },
-       { "contextLength", m_contextLength },
-       { "gpuLayers", m_gpuLayers },
-       { "repeatPenalty", m_repeatPenalty },
-       { "repeatPenaltyTokens", m_repeatPenaltyTokens },
-       { "promptTemplate", m_promptTemplate },
-       { "systemPrompt", m_systemPrompt },
-       { "chatNamePrompt", m_chatNamePrompt },
-       { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt },
+   static const std::unordered_map<QLatin1StringView, QVariant(*)(const ModelInfo &)> s_fields = {
+       { "filename"_L1,                [](auto &i) -> QVariant { return i.m_filename; } },
+       { "description"_L1,             [](auto &i) -> QVariant { return i.m_description; } },
+       { "url"_L1,                     [](auto &i) -> QVariant { return i.m_url; } },
+       { "quant"_L1,                   [](auto &i) -> QVariant { return i.m_quant; } },
+       { "type"_L1,                    [](auto &i) -> QVariant { return i.m_type; } },
+       { "isClone"_L1,                 [](auto &i) -> QVariant { return i.m_isClone; } },
+       { "isDiscovered"_L1,            [](auto &i) -> QVariant { return i.m_isDiscovered; } },
+       { "likes"_L1,                   [](auto &i) -> QVariant { return i.m_likes; } },
+       { "downloads"_L1,               [](auto &i) -> QVariant { return i.m_downloads; } },
+       { "recency"_L1,                 [](auto &i) -> QVariant { return i.m_recency; } },
+       { "temperature"_L1,             [](auto &i) -> QVariant { return i.m_temperature; } },
+       { "topP"_L1,                    [](auto &i) -> QVariant { return i.m_topP; } },
+       { "minP"_L1,                    [](auto &i) -> QVariant { return i.m_minP; } },
+       { "topK"_L1,                    [](auto &i) -> QVariant { return i.m_topK; } },
+       { "maxLength"_L1,               [](auto &i) -> QVariant { return i.m_maxLength; } },
+       { "promptBatchSize"_L1,         [](auto &i) -> QVariant { return i.m_promptBatchSize; } },
+       { "contextLength"_L1,           [](auto &i) -> QVariant { return i.m_contextLength; } },
+       { "gpuLayers"_L1,               [](auto &i) -> QVariant { return i.m_gpuLayers; } },
+       { "repeatPenalty"_L1,           [](auto &i) -> QVariant { return i.m_repeatPenalty; } },
+       { "repeatPenaltyTokens"_L1,     [](auto &i) -> QVariant { return i.m_repeatPenaltyTokens; } },
+       { "chatTemplate"_L1,            [](auto &i) -> QVariant { return i.defaultChatTemplate(); } },
+       { "systemMessage"_L1,           [](auto &i) -> QVariant { return i.m_systemMessage; } },
+       { "chatNamePrompt"_L1,          [](auto &i) -> QVariant { return i.m_chatNamePrompt; } },
+       { "suggestedFollowUpPrompt"_L1, [](auto &i) -> QVariant { return i.m_suggestedFollowUpPrompt; } },
    };
+   return s_fields.at(name)(*this);
}

InstalledModels::InstalledModels(QObject *parent, bool selectable)
@@ -491,31 +511,48 @@ ModelList::ModelList()
    m_selectableModels->setSourceModel(this);
    m_downloadableModels->setSourceModel(this);

-   connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
-   connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson);
-   connect(MySettings::globalInstance(), &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings);
-   connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::gpuLayersChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
-   connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings);
+   auto *mySettings = MySettings::globalInstance();
+   connect(mySettings, &MySettings::nameChanged,                this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::temperatureChanged,         this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::topPChanged,                this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::minPChanged,                this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::topKChanged,                this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::maxLengthChanged,           this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::promptBatchSizeChanged,     this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::contextLengthChanged,       this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::gpuLayersChanged,           this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::repeatPenaltyChanged,       this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings     );
+   connect(mySettings, &MySettings::chatTemplateChanged,        this, &ModelList::maybeUpdateDataForSettings);
+   connect(mySettings, &MySettings::systemMessageChanged,       this, &ModelList::maybeUpdateDataForSettings);
+
+   connect(this, &ModelList::dataChanged, this, &ModelList::onDataChanged);

    connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors);

    updateModelsFromJson();
    updateModelsFromSettings();
    updateModelsFromDirectory();
+
+   connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromDirectory);
+   connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromJson     );
+   connect(mySettings, &MySettings::modelPathChanged, this, &ModelList::updateModelsFromSettings );
+
    QCoreApplication::instance()->installEventFilter(this);
}

+// an easier way to listen for model info and setting changes
+void ModelList::onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
+{
+   Q_UNUSED(roles)
+   for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
+       auto index = topLeft.siblingAtRow(row);
+       auto id = index.data(ModelList::IdRole).toString();
+       if (auto info = modelInfo(id); !info.id().isNull())
+           emit modelInfoChanged(info);
+   }
+}
+
QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) {
    QCryptographicHash sha256(QCryptographicHash::Sha256);
    sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8());
@@ -776,10 +813,10 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
        return info->repeatPenalty();
    case RepeatPenaltyTokensRole:
        return info->repeatPenaltyTokens();
-   case PromptTemplateRole:
-       return info->promptTemplate();
-   case SystemPromptRole:
-       return info->systemPrompt();
+   case ChatTemplateRole:
+       return QVariant::fromValue(info->chatTemplate());
+   case SystemMessageRole:
+       return QVariant::fromValue(info->systemMessage());
    case ChatNamePromptRole:
        return info->chatNamePrompt();
    case SuggestedFollowUpPromptRole:
@@ -952,10 +989,10 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
            info->setRepeatPenalty(value.toDouble()); break;
        case RepeatPenaltyTokensRole:
            info->setRepeatPenaltyTokens(value.toInt()); break;
-       case PromptTemplateRole:
-           info->setPromptTemplate(value.toString()); break;
-       case SystemPromptRole:
-           info->setSystemPrompt(value.toString()); break;
+       case ChatTemplateRole:
+           info->m_chatTemplate = value.toString(); break;
+       case SystemMessageRole:
+           info->m_systemMessage = value.toString(); break;
        case ChatNamePromptRole:
            info->setChatNamePrompt(value.toString()); break;
        case SuggestedFollowUpPromptRole:
@@ -1056,11 +1093,11 @@ ModelInfo ModelList::modelInfo(const QString &id) const
    return *m_modelMap.value(id);
}

-ModelInfo ModelList::modelInfoByFilename(const QString &filename) const
+ModelInfo ModelList::modelInfoByFilename(const QString &filename, bool allowClone) const
{
    QMutexLocker locker(&m_mutex);
    for (ModelInfo *info : m_models)
-       if (info->filename() == filename)
+       if (info->filename() == filename && (allowClone || !info->isClone()))
            return *info;
    return ModelInfo();
}
@@ -1080,6 +1117,20 @@ QString ModelList::clone(const ModelInfo &model)
    const QString id = Network::globalInstance()->generateUniqueId();
    addModel(id);

+   QString chatTemplate, systemMessage;
+   if (auto tmpl = model.chatTemplate().asModern()) {
+       chatTemplate = *tmpl;
+   } else {
+       qWarning("ModelList Warning: attempted to clone model with legacy chat template");
+       return {};
+   }
+   if (auto msg = model.systemMessage().asModern()) {
+       systemMessage = *msg;
+   } else {
+       qWarning("ModelList Warning: attempted to clone model with legacy system message");
+       return {};
+   }
+
    QVector<QPair<int, QVariant>> data {
        { ModelList::InstalledRole, model.installed },
        { ModelList::IsCloneRole, true },
@@ -1099,8 +1150,8 @@ QString ModelList::clone(const ModelInfo &model)
        { ModelList::GpuLayersRole, model.gpuLayers() },
        { ModelList::RepeatPenaltyRole, model.repeatPenalty() },
        { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
-       { ModelList::PromptTemplateRole, model.promptTemplate() },
-       { ModelList::SystemPromptRole, model.systemPrompt() },
+       { ModelList::ChatTemplateRole, chatTemplate },
+       { ModelList::SystemMessageRole, systemMessage },
        { ModelList::ChatNamePromptRole, model.chatNamePrompt() },
        { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
    };
@@ -1125,21 +1176,23 @@ void ModelList::removeInstalled(const ModelInfo &model)
    removeInternal(model);
}

+int ModelList::indexByModelId(const QString &id) const
+{
+   QMutexLocker locker(&m_mutex);
+   if (auto it = m_modelMap.find(id); it != m_modelMap.cend())
+       return m_models.indexOf(*it);
+   return -1;
+}
+
void ModelList::removeInternal(const ModelInfo &model)
{
-   const bool hasModel = contains(model.id());
-   Q_ASSERT(hasModel);
-   if (!hasModel) {
+   int indexOfModel = indexByModelId(model.id());
+   Q_ASSERT(indexOfModel != -1);
+   if (indexOfModel == -1) {
        qWarning() << "ERROR: model list does not contain" << model.id();
        return;
    }

-   int indexOfModel = 0;
-   {
-       QMutexLocker locker(&m_mutex);
-       ModelInfo *info = m_modelMap.value(model.id());
-       indexOfModel = m_models.indexOf(info);
-   }
    beginRemoveRows(QModelIndex(), indexOfModel, indexOfModel);
    {
        QMutexLocker locker(&m_mutex);
@@ -1314,8 +1367,6 @@ void ModelList::processModelDirectory(const QString &path)
            // The description is hard-coded into "GPT4All.ini" due to performance issue.
            // If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList.
            data.append({ DescriptionRole, description });
-           // Prompt template should be clear while using ChatML format which is using in most of OpenAI-Compatible API server.
-           data.append({ PromptTemplateRole, "%1" });
        }
        updateData(id, data);
    }
@@ -1451,9 +1502,20 @@ void ModelList::handleSslErrors(QNetworkReply *reply, const QList<QSslError> &er
        qWarning() << "ERROR: Received ssl error:" << e.errorString() << "for" << url;
}

+void ModelList::maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo)
+{
+   // ignore updates that were *because* of a dataChanged - would cause a circular dependency
+   int idx;
+   if (!fromInfo && (idx = indexByModelId(info.id())) != -1) {
+       emit dataChanged(index(idx, 0), index(idx, 0));
+       emit selectableModelListChanged();
+   }
+}
+
void ModelList::updateDataForSettings()
{
    emit dataChanged(index(0, 0), index(m_models.size() - 1, 0));
+   emit selectableModelListChanged();
}

void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
@@ -1560,10 +1622,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
        data.append({ ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble() });
    if (obj.contains("repeatPenaltyTokens"))
        data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() });
-   if (obj.contains("promptTemplate"))
-       data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() });
-   if (obj.contains("systemPrompt"))
-       data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() });
+   if (auto it = obj.find("chatTemplate"_L1); it != obj.end())
+       data.append({ ModelList::ChatTemplateRole, it->toString() });
+   if (auto it = obj.find("systemMessage"_L1); it != obj.end())
+       data.append({ ModelList::SystemMessageRole, it->toString() });
    updateData(id, data);
}

@@ -1755,6 +1817,9 @@ void ModelList::updateDiscoveredInstalled(const ModelInfo &info)
    updateData(info.id(), data);
}

+// FIXME(jared): This should only contain fields without reasonable defaults such as name, description, and URL.
+//               For other settings, there is no authoritative value and we should load the setting lazily like we do
+//               for any other override.
void ModelList::updateModelsFromSettings()
{
    QSettings settings;
@@ -1769,12 +1834,27 @@ void ModelList::updateModelsFromSettings()

        // If we can't find the corresponding file, then ignore it as this reflects a stale model.
        // The file could have been deleted manually by the user for instance or temporarily renamed.
-       if (!settings.contains(g + "/filename") || !modelExists(settings.value(g + "/filename").toString()))
-           continue;
+       QString filename;
+       {
+           auto value = settings.value(u"%1/filename"_s.arg(g));
+           if (!value.isValid() || !modelExists(filename = value.toString()))
+               continue;
+       }
+
+       QVector<QPair<int, QVariant>> data;
+
+       // load data from base model
+       // FIXME(jared): how does "Restore Defaults" work for other settings of clones which we don't do this for?
+       if (auto base = modelInfoByFilename(filename, /*allowClone*/ false); !base.id().isNull()) {
+           if (auto tmpl = base.m_chatTemplate)
+               data.append({ ModelList::ChatTemplateRole, *tmpl });
+           if (auto msg = base.m_systemMessage; !msg.isNull())
+               data.append({ ModelList::SystemMessageRole, msg });
+       }
+
        addModel(id);

-       QVector<QPair<int, QVariant>> data;
+       // load data from settings
        if (settings.contains(g + "/name")) {
            const QString name = settings.value(g + "/name").toString();
            data.append({ ModelList::NameRole, name });
@@ -1859,14 +1939,6 @@ void ModelList::updateModelsFromSettings()
            const int repeatPenaltyTokens = settings.value(g + "/repeatPenaltyTokens").toInt();
            data.append({ ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens });
        }
-       if (settings.contains(g + "/promptTemplate")) {
-           const QString promptTemplate = settings.value(g + "/promptTemplate").toString();
-           data.append({ ModelList::PromptTemplateRole, promptTemplate });
-       }
-       if (settings.contains(g + "/systemPrompt")) {
-           const QString systemPrompt = settings.value(g + "/systemPrompt").toString();
-           data.append({ ModelList::SystemPromptRole, systemPrompt });
-       }
        if (settings.contains(g + "/chatNamePrompt")) {
            const QString chatNamePrompt = settings.value(g + "/chatNamePrompt").toString();
            data.append({ ModelList::ChatNamePromptRole, chatNamePrompt });
|
@ -5,12 +5,14 @@
|
|||||||
#include <QByteArray>
|
#include <QByteArray>
|
||||||
#include <QDateTime>
|
#include <QDateTime>
|
||||||
#include <QHash>
|
#include <QHash>
|
||||||
|
#include <QLatin1StringView>
|
||||||
#include <QList>
|
#include <QList>
|
||||||
#include <QMutex>
|
#include <QMutex>
|
||||||
#include <QNetworkAccessManager>
|
#include <QNetworkAccessManager>
|
||||||
#include <QNetworkReply>
|
#include <QNetworkReply>
|
||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QPair>
|
#include <QPair>
|
||||||
|
#include <QQmlEngine>
|
||||||
#include <QSortFilterProxyModel>
|
#include <QSortFilterProxyModel>
|
||||||
#include <QSslError>
|
#include <QSslError>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
@ -19,11 +21,53 @@
|
|||||||
#include <Qt>
|
#include <Qt>
|
||||||
#include <QtGlobal>
|
#include <QtGlobal>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
using namespace Qt::Literals::StringLiterals;
|
using namespace Qt::Literals::StringLiterals;
|
||||||
|
|
||||||
|
|
||||||
|
class UpgradeableSetting {
|
||||||
|
Q_GADGET
|
||||||
|
QML_ANONYMOUS
|
||||||
|
|
||||||
|
// NOTE: Unset implies there is neither a value nor a default
|
||||||
|
enum class State { Unset, Legacy, Modern };
|
||||||
|
|
||||||
|
Q_PROPERTY(bool isSet READ isSet )
|
||||||
|
Q_PROPERTY(bool isLegacy READ isLegacy)
|
||||||
|
Q_PROPERTY(bool isModern READ isModern)
|
||||||
|
Q_PROPERTY(QVariant value READ value) // string or null
|
||||||
|
|
||||||
|
public:
|
||||||
|
struct legacy_tag_t { explicit legacy_tag_t() = default; };
|
||||||
|
static inline constexpr legacy_tag_t legacy_tag = legacy_tag_t();
|
||||||
|
|
||||||
|
UpgradeableSetting() : m_state(State::Unset ) {}
|
||||||
|
UpgradeableSetting(legacy_tag_t, QString value): m_state(State::Legacy), m_value(std::move(value)) {}
|
||||||
|
UpgradeableSetting( QString value): m_state(State::Modern), m_value(std::move(value)) {}
|
||||||
|
|
||||||
|
bool isSet () const { return m_state != State::Unset; }
|
||||||
|
bool isLegacy() const { return m_state == State::Legacy; }
|
||||||
|
bool isModern() const { return m_state == State::Modern; }
|
||||||
|
QVariant value () const { return m_state == State::Unset ? QVariant::fromValue(nullptr) : m_value; }
|
||||||
|
|
||||||
|
friend bool operator==(const UpgradeableSetting &a, const UpgradeableSetting &b)
|
||||||
|
{ return a.m_state == b.m_state && (a.m_state == State::Unset || a.m_value == b.m_value); }
|
||||||
|
|
||||||
|
// returns std::nullopt if there is a legacy template or it is not set
|
||||||
|
std::optional<QString> asModern() const
|
||||||
|
{
|
||||||
|
if (m_state == State::Modern)
|
||||||
|
return m_value;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
State m_state;
|
||||||
|
QString m_value;
|
||||||
|
};
|
||||||
|
|
||||||
struct ModelInfo {
|
struct ModelInfo {
|
||||||
Q_GADGET
|
Q_GADGET
|
||||||
Q_PROPERTY(QString id READ id WRITE setId)
|
Q_PROPERTY(QString id READ id WRITE setId)
|
||||||
@ -69,8 +113,11 @@ struct ModelInfo {
|
|||||||
Q_PROPERTY(int maxGpuLayers READ maxGpuLayers)
|
Q_PROPERTY(int maxGpuLayers READ maxGpuLayers)
|
||||||
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
|
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
|
||||||
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
|
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
|
||||||
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
|
// user-defined chat template and system message must be written through settings because of their legacy compat
|
||||||
Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt)
|
Q_PROPERTY(QVariant defaultChatTemplate READ defaultChatTemplate )
|
||||||
|
Q_PROPERTY(UpgradeableSetting chatTemplate READ chatTemplate )
|
||||||
|
Q_PROPERTY(QString defaultSystemMessage READ defaultSystemMessage)
|
||||||
|
Q_PROPERTY(UpgradeableSetting systemMessage READ systemMessage )
|
||||||
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
|
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
|
||||||
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
|
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
|
||||||
Q_PROPERTY(int likes READ likes WRITE setLikes)
|
Q_PROPERTY(int likes READ likes WRITE setLikes)
|
||||||
@ -178,19 +225,22 @@ public:
|
|||||||
void setRepeatPenalty(double p);
|
void setRepeatPenalty(double p);
|
||||||
int repeatPenaltyTokens() const;
|
int repeatPenaltyTokens() const;
|
||||||
void setRepeatPenaltyTokens(int t);
|
void setRepeatPenaltyTokens(int t);
|
||||||
QString promptTemplate() const;
|
QVariant defaultChatTemplate() const;
|
||||||
void setPromptTemplate(const QString &t);
|
UpgradeableSetting chatTemplate() const;
|
||||||
QString systemPrompt() const;
|
QString defaultSystemMessage() const;
|
||||||
void setSystemPrompt(const QString &p);
|
UpgradeableSetting systemMessage() const;
|
||||||
QString chatNamePrompt() const;
|
QString chatNamePrompt() const;
|
||||||
void setChatNamePrompt(const QString &p);
|
void setChatNamePrompt(const QString &p);
|
||||||
QString suggestedFollowUpPrompt() const;
|
QString suggestedFollowUpPrompt() const;
|
||||||
void setSuggestedFollowUpPrompt(const QString &p);
|
void setSuggestedFollowUpPrompt(const QString &p);
|
||||||
|
|
||||||
|
// Some metadata must be saved to settings because it does not have a meaningful default from some other source.
|
||||||
|
// This is useful for fields such as name, description, and URL.
|
||||||
|
// It is true for any models that have not been installed from models.json.
|
||||||
bool shouldSaveMetadata() const;
|
bool shouldSaveMetadata() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
QVariantMap getFields() const;
|
QVariant getField(QLatin1StringView name) const;
|
||||||
|
|
||||||
QString m_id;
|
QString m_id;
|
||||||
QString m_name;
|
QString m_name;
|
||||||
@ -216,11 +266,13 @@ private:
|
|||||||
mutable int m_maxGpuLayers = -1;
|
mutable int m_maxGpuLayers = -1;
|
||||||
double m_repeatPenalty = 1.18;
|
double m_repeatPenalty = 1.18;
|
||||||
int m_repeatPenaltyTokens = 64;
|
int m_repeatPenaltyTokens = 64;
|
||||||
QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n";
|
std::optional<QString> m_chatTemplate;
|
||||||
QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
|
mutable std::optional<QString> m_modelChatTemplate;
|
||||||
|
QString m_systemMessage;
|
||||||
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
|
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
|
||||||
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
|
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
|
||||||
friend class MySettings;
|
friend class MySettings;
|
||||||
|
friend class ModelList;
|
||||||
};
|
};
|
||||||
Q_DECLARE_METATYPE(ModelInfo)
|
Q_DECLARE_METATYPE(ModelInfo)
|
||||||
|
|
||||||
@ -340,8 +392,8 @@ public:
|
|||||||
GpuLayersRole,
|
GpuLayersRole,
|
||||||
RepeatPenaltyRole,
|
RepeatPenaltyRole,
|
||||||
RepeatPenaltyTokensRole,
|
RepeatPenaltyTokensRole,
|
||||||
PromptTemplateRole,
|
ChatTemplateRole,
|
||||||
SystemPromptRole,
|
SystemMessageRole,
|
||||||
ChatNamePromptRole,
|
ChatNamePromptRole,
|
||||||
SuggestedFollowUpPromptRole,
|
SuggestedFollowUpPromptRole,
|
||||||
MinPRole,
|
MinPRole,
|
||||||
@ -394,8 +446,8 @@ public:
|
|||||||
roles[GpuLayersRole] = "gpuLayers";
|
roles[GpuLayersRole] = "gpuLayers";
|
||||||
roles[RepeatPenaltyRole] = "repeatPenalty";
|
roles[RepeatPenaltyRole] = "repeatPenalty";
|
||||||
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
|
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
|
||||||
roles[PromptTemplateRole] = "promptTemplate";
|
roles[ChatTemplateRole] = "chatTemplate";
|
||||||
roles[SystemPromptRole] = "systemPrompt";
|
roles[SystemMessageRole] = "systemMessage";
|
||||||
roles[ChatNamePromptRole] = "chatNamePrompt";
|
roles[ChatNamePromptRole] = "chatNamePrompt";
|
||||||
roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";
|
roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";
|
||||||
roles[LikesRole] = "likes";
|
roles[LikesRole] = "likes";
|
||||||
@ -416,7 +468,7 @@ public:
|
|||||||
bool contains(const QString &id) const;
|
bool contains(const QString &id) const;
|
||||||
bool containsByFilename(const QString &filename) const;
|
bool containsByFilename(const QString &filename) const;
|
||||||
Q_INVOKABLE ModelInfo modelInfo(const QString &id) const;
|
Q_INVOKABLE ModelInfo modelInfo(const QString &id) const;
|
||||||
Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename) const;
|
Q_INVOKABLE ModelInfo modelInfoByFilename(const QString &filename, bool allowClone = true) const;
|
||||||
Q_INVOKABLE bool isUniqueName(const QString &name) const;
|
Q_INVOKABLE bool isUniqueName(const QString &name) const;
|
||||||
Q_INVOKABLE QString clone(const ModelInfo &model);
|
Q_INVOKABLE QString clone(const ModelInfo &model);
|
||||||
Q_INVOKABLE void removeClone(const ModelInfo &model);
|
Q_INVOKABLE void removeClone(const ModelInfo &model);
|
||||||
@ -476,15 +528,18 @@ Q_SIGNALS:
|
|||||||
void discoverSortChanged();
|
void discoverSortChanged();
|
||||||
void discoverProgressChanged();
|
void discoverProgressChanged();
|
||||||
void discoverInProgressChanged();
|
void discoverInProgressChanged();
|
||||||
|
void modelInfoChanged(const ModelInfo &info);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool eventFilter(QObject *obj, QEvent *ev) override;
|
bool eventFilter(QObject *obj, QEvent *ev) override;
|
||||||
|
|
||||||
private Q_SLOTS:
|
private Q_SLOTS:
|
||||||
|
void onDataChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles);
|
||||||
void resortModel();
|
void resortModel();
|
||||||
void updateModelsFromJson();
|
void updateModelsFromJson();
|
||||||
void updateModelsFromJsonAsync();
|
void updateModelsFromJsonAsync();
|
||||||
void updateModelsFromSettings();
|
void updateModelsFromSettings();
|
||||||
|
void maybeUpdateDataForSettings(const ModelInfo &info, bool fromInfo);
|
||||||
void updateDataForSettings();
|
void updateDataForSettings();
|
||||||
void handleModelsJsonDownloadFinished();
|
void handleModelsJsonDownloadFinished();
|
||||||
void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code);
|
void handleModelsJsonDownloadErrorOccurred(QNetworkReply::NetworkError code);
|
||||||
@ -495,6 +550,9 @@ private Q_SLOTS:
|
|||||||
void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors);
|
void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Return the index of the model with the given id, or -1 if not found.
|
||||||
|
int indexByModelId(const QString &id) const;
|
||||||
|
|
||||||
void removeInternal(const ModelInfo &model);
|
void removeInternal(const ModelInfo &model);
|
||||||
void clearDiscoveredModels();
|
void clearDiscoveredModels();
|
||||||
bool modelExists(const QString &fileName) const;
|
bool modelExists(const QString &fileName) const;
|
||||||
|
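For reference, the UpgradeableSetting introduced above distinguishes an unset value, a legacy (pre-Jinja) value, and a modern Jinja value; only the modern form can be used directly as a chat template. A small sketch of the three states follows, using only the API shown in this header; the template strings are placeholders, and the snippet assumes modellist.h is included.

// Sketch only, not part of the commit.
static void upgradeableSettingStates()
{
    UpgradeableSetting unset;
    UpgradeableSetting legacy(UpgradeableSetting::legacy_tag,
                              QStringLiteral("### Human:\n%1\n\n### Assistant:\n")); // old %1-style template
    UpgradeableSetting modern(QStringLiteral("{% for message in messages %}...{% endfor %}"));

    Q_ASSERT(!unset.isSet());                                     // value() reports null to QML
    Q_ASSERT(legacy.isLegacy() && !legacy.asModern());            // legacy values must be upgraded before use
    Q_ASSERT(modern.isModern() && modern.asModern().has_value()); // only modern values are used as Jinja templates
}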
@ -1,5 +1,8 @@
|
|||||||
#include "mysettings.h"
|
#include "mysettings.h"
|
||||||
|
|
||||||
|
#include "chatllm.h"
|
||||||
|
#include "modellist.h"
|
||||||
|
|
||||||
#include <gpt4all-backend/llmodel.h>
|
#include <gpt4all-backend/llmodel.h>
|
||||||
|
|
||||||
#include <QDebug>
|
#include <QDebug>
|
||||||
@ -29,8 +32,13 @@ static const QStringList suggestionModeNames { "LocalDocsOnly", "On", "Off" };
|
|||||||
static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" };
|
static const QStringList chatThemeNames { "Light", "Dark", "LegacyDark" };
|
||||||
static const QStringList fontSizeNames { "Small", "Medium", "Large" };
|
static const QStringList fontSizeNames { "Small", "Medium", "Large" };
|
||||||
|
|
||||||
// FIXME: All of these default strings that are shown in the UI for settings need to be marked as
|
// psuedo-enum
|
||||||
// translatable
|
namespace ModelSettingsKey { namespace {
|
||||||
|
auto ChatTemplate = "chatTemplate"_L1;
|
||||||
|
auto PromptTemplate = "promptTemplate"_L1; // legacy
|
||||||
|
auto SystemMessage = "systemMessage"_L1;
|
||||||
|
auto SystemPrompt = "systemPrompt"_L1; // legacy
|
||||||
|
} } // namespace ModelSettingsKey::(anonymous)
|
||||||
|
|
||||||
namespace defaults {
|
namespace defaults {
|
||||||
|
|
||||||
@ -48,7 +56,6 @@ static const QVariantMap basicDefaults {
|
|||||||
{ "fontSize", QVariant::fromValue(FontSize::Small) },
|
{ "fontSize", QVariant::fromValue(FontSize::Small) },
|
||||||
{ "lastVersionStarted", "" },
|
{ "lastVersionStarted", "" },
|
||||||
{ "networkPort", 4891, },
|
{ "networkPort", 4891, },
|
||||||
{ "saveChatsContext", false },
|
|
||||||
{ "systemTray", false },
|
{ "systemTray", false },
|
||||||
{ "serverChat", false },
|
{ "serverChat", false },
|
||||||
{ "userDefaultModel", "Application default" },
|
{ "userDefaultModel", "Application default" },
|
||||||
@ -147,6 +154,11 @@ static QStringList getUiLanguages(const QString &modelPath)
|
|||||||
return languageList;
|
return languageList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static QString modelSettingName(const ModelInfo &info, auto &&name)
|
||||||
|
{
|
||||||
|
return u"model-%1/%2"_s.arg(info.id(), name);
|
||||||
|
}
|
||||||
|
|
||||||
class MyPrivateSettings: public MySettings { };
|
class MyPrivateSettings: public MySettings { };
|
||||||
Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance)
|
Q_GLOBAL_STATIC(MyPrivateSettings, settingsInstance)
|
||||||
MySettings *MySettings::globalInstance()
|
MySettings *MySettings::globalInstance()
|
||||||
@ -162,6 +174,34 @@ MySettings::MySettings()
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
|
||||||
|
{
|
||||||
|
if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
|
||||||
|
return QString::fromStdString(*err);
|
||||||
|
return QVariant::fromValue(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unset settings come from ModelInfo. Listen for changes so we can emit our own setting-specific signals.
|
||||||
|
void MySettings::onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles)
|
||||||
|
{
|
||||||
|
auto settingChanged = [&](const auto &info, auto role, const auto &name) {
|
||||||
|
return (roles.isEmpty() || roles.contains(role)) && !m_settings.contains(modelSettingName(info, name));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto &modelList = dynamic_cast<const ModelList &>(*QObject::sender());
|
||||||
|
for (int row = topLeft.row(); row <= bottomRight.row(); row++) {
|
||||||
|
using enum ModelList::Roles;
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
auto index = topLeft.siblingAtRow(row);
|
||||||
|
if (auto info = modelList.modelInfo(index.data(IdRole).toString()); !info.id().isNull()) {
|
||||||
|
if (settingChanged(info, ChatTemplateRole, ChatTemplate))
|
||||||
|
emit chatTemplateChanged(info, /*fromInfo*/ true);
|
||||||
|
if (settingChanged(info, SystemMessageRole, SystemMessage))
|
||||||
|
emit systemMessageChanged(info, /*fromInfo*/ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
QVariant MySettings::getBasicSetting(const QString &name) const
|
QVariant MySettings::getBasicSetting(const QString &name) const
|
||||||
{
|
{
|
||||||
return m_settings.value(name, basicDefaults.value(name));
|
return m_settings.value(name, basicDefaults.value(name));
|
||||||
@ -194,8 +234,8 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
|
|||||||
setModelGpuLayers(info, info.m_gpuLayers);
|
setModelGpuLayers(info, info.m_gpuLayers);
|
||||||
setModelRepeatPenalty(info, info.m_repeatPenalty);
|
setModelRepeatPenalty(info, info.m_repeatPenalty);
|
||||||
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
|
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
|
||||||
setModelPromptTemplate(info, info.m_promptTemplate);
|
resetModelChatTemplate (info);
|
||||||
setModelSystemPrompt(info, info.m_systemPrompt);
|
resetModelSystemMessage(info);
|
||||||
setModelChatNamePrompt(info, info.m_chatNamePrompt);
|
setModelChatNamePrompt(info, info.m_chatNamePrompt);
|
||||||
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
|
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
|
||||||
}
|
}
|
||||||
@ -206,7 +246,6 @@ void MySettings::restoreApplicationDefaults()
|
|||||||
setFontSize(basicDefaults.value("fontSize").value<FontSize>());
|
setFontSize(basicDefaults.value("fontSize").value<FontSize>());
|
||||||
setDevice(defaults::device);
|
setDevice(defaults::device);
|
||||||
setThreadCount(defaults::threadCount);
|
setThreadCount(defaults::threadCount);
|
||||||
setSaveChatsContext(basicDefaults.value("saveChatsContext").toBool());
|
|
||||||
setSystemTray(basicDefaults.value("systemTray").toBool());
|
setSystemTray(basicDefaults.value("systemTray").toBool());
|
||||||
setServerChat(basicDefaults.value("serverChat").toBool());
|
setServerChat(basicDefaults.value("serverChat").toBool());
|
||||||
setNetworkPort(basicDefaults.value("networkPort").toInt());
|
setNetworkPort(basicDefaults.value("networkPort").toInt());
|
||||||
@ -252,29 +291,37 @@ void MySettings::setModelName(const ModelInfo &info, const QString &value, bool
|
|||||||
emit nameChanged(info);
|
emit nameChanged(info);
|
||||||
}
|
}
|
||||||
|
|
||||||
static QString modelSettingName(const ModelInfo &info, const QString &name)
|
QVariant MySettings::getModelSetting(QLatin1StringView name, const ModelInfo &info) const
|
||||||
{
|
{
|
||||||
return u"model-%1/%2"_s.arg(info.id(), name);
|
QLatin1StringView nameL1(name);
|
||||||
|
return m_settings.value(modelSettingName(info, nameL1), info.getField(nameL1));
|
||||||
}
|
}
|
||||||
|
|
||||||
QVariant MySettings::getModelSetting(const QString &name, const ModelInfo &info) const
|
QVariant MySettings::getModelSetting(const char *name, const ModelInfo &info) const
|
||||||
{
|
{
|
||||||
return m_settings.value(modelSettingName(info, name), info.getFields().value(name));
|
return getModelSetting(QLatin1StringView(name), info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySettings::setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force,
|
void MySettings::setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
|
||||||
bool signal)
|
bool signal)
|
||||||
{
|
{
|
||||||
if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value))
|
if (!force && (info.id().isEmpty() || getModelSetting(name, info) == value))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
QString settingName = modelSettingName(info, name);
|
QLatin1StringView nameL1(name);
|
||||||
if (info.getFields().value(name) == value && !info.shouldSaveMetadata())
|
QString settingName = modelSettingName(info, nameL1);
|
||||||
|
if (info.getField(nameL1) == value && !info.shouldSaveMetadata())
|
||||||
m_settings.remove(settingName);
|
m_settings.remove(settingName);
|
||||||
else
|
else
|
||||||
m_settings.setValue(settingName, value);
|
m_settings.setValue(settingName, value);
|
||||||
if (signal && !force)
|
if (signal && !force)
|
||||||
QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(name).toLatin1().constData(), Q_ARG(ModelInfo, info));
|
QMetaObject::invokeMethod(this, u"%1Changed"_s.arg(nameL1).toLatin1().constData(), Q_ARG(ModelInfo, info));
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySettings::setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
|
||||||
|
bool signal)
|
||||||
|
{
|
||||||
|
setModelSetting(QLatin1StringView(name), info, value, force, signal);
|
||||||
}
|
}
|
||||||
|
|
||||||
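Editor's note: a short recap of the storage scheme these overloads share, paraphrasing the code above rather than adding behaviour. The model id is hypothetical, and the _L1 suffix assumes Qt::Literals::StringLiterals is in scope as elsewhere in this commit.

    // Inside MySettings, a per-model override lives under "model-<id>/<field>"; for a model
    // whose id() is "llama-3.2-1b" the chat template override key would be
    //   model-llama-3.2-1b/chatTemplate
    // and the lookup falls back to the matching ModelInfo field when no override exists:
    QVariant v = m_settings.value(modelSettingName(info, "chatTemplate"_L1),
                                  info.getField("chatTemplate"_L1));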
QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); }
|
QString MySettings::modelFilename (const ModelInfo &info) const { return getModelSetting("filename", info).toString(); }
|
||||||
@ -297,11 +344,68 @@ int MySettings::modelContextLength (const ModelInfo &info) const
|
|||||||
int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); }
|
int MySettings::modelGpuLayers (const ModelInfo &info) const { return getModelSetting("gpuLayers", info).toInt(); }
|
||||||
double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
|
double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
|
||||||
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
|
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
|
||||||
QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
|
|
||||||
QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
|
|
||||||
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
|
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
|
||||||
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
|
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
|
||||||
|
|
||||||
|
auto MySettings::getUpgradeableModelSetting(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
) const -> UpgradeableSetting
|
||||||
|
{
|
||||||
|
if (info.id().isEmpty()) {
|
||||||
|
qWarning("%s: got null model", Q_FUNC_INFO);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto value = m_settings.value(modelSettingName(info, legacyKey));
|
||||||
|
if (value.isValid())
|
||||||
|
return { UpgradeableSetting::legacy_tag, value.toString() };
|
||||||
|
|
||||||
|
value = getModelSetting(newKey, info);
|
||||||
|
if (!value.isNull())
|
||||||
|
return value.toString();
|
||||||
|
return {}; // neither a default nor an override
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MySettings::isUpgradeableModelSettingSet(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
) const
|
||||||
|
{
|
||||||
|
if (info.id().isEmpty()) {
|
||||||
|
qWarning("%s: got null model", Q_FUNC_INFO);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (m_settings.contains(modelSettingName(info, legacyKey)))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// NOTE: unlike getUpgradeableSetting(), this ignores the default
|
||||||
|
return m_settings.contains(modelSettingName(info, newKey));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto MySettings::modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
return getUpgradeableModelSetting(info, PromptTemplate, ChatTemplate);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MySettings::isModelChatTemplateSet(const ModelInfo &info) const
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
return isUpgradeableModelSettingSet(info, PromptTemplate, ChatTemplate);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto MySettings::modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
return getUpgradeableModelSetting(info, SystemPrompt, SystemMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MySettings::isModelSystemMessageSet(const ModelInfo &info) const
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
return isUpgradeableModelSettingSet(info, SystemPrompt, SystemMessage);
|
||||||
|
}
|
||||||
|
|
||||||
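Editor's note: a hedged sketch of how a caller observes the legacy-to-new upgrade implemented above. asModern() is the accessor this commit uses later in Network::packageAndSendJson; the consumer function is hypothetical.

    // A lingering "promptTemplate" override is reported as a legacy value; otherwise the
    // "chatTemplate" setting, or the model's built-in default, is returned.
    auto tmpl = MySettings::globalInstance()->modelChatTemplate(info);
    if (auto modern = tmpl.asModern())   // empty while the template is still legacy-style
        applyChatTemplate(*modern);      // hypothetical consumer taking a QString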
void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force)
|
void MySettings::setModelFilename(const ModelInfo &info, const QString &value, bool force)
|
||||||
{
|
{
|
||||||
setModelSetting("filename", info, value, force, true);
|
setModelSetting("filename", info, value, force, true);
|
||||||
@ -402,14 +506,77 @@ void MySettings::setModelRepeatPenaltyTokens(const ModelInfo &info, int value, b
|
|||||||
setModelSetting("repeatPenaltyTokens", info, value, force, true);
|
setModelSetting("repeatPenaltyTokens", info, value, force, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force)
|
bool MySettings::setUpgradeableModelSetting(
|
||||||
{
|
const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
setModelSetting("promptTemplate", info, value, force, true);
|
) {
|
||||||
|
if (info.id().isEmpty()) {
|
||||||
|
qWarning("%s: got null model", Q_FUNC_INFO);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto legacyModelKey = modelSettingName(info, legacyKey);
|
||||||
|
auto newModelKey = modelSettingName(info, newKey );
|
||||||
|
bool changed = false;
|
||||||
|
if (m_settings.contains(legacyModelKey)) {
|
||||||
|
m_settings.remove(legacyModelKey);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
auto oldValue = m_settings.value(newModelKey);
|
||||||
|
if (!oldValue.isValid() || oldValue.toString() != value) {
|
||||||
|
m_settings.setValue(newModelKey, value);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force)
|
bool MySettings::resetUpgradeableModelSetting(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
) {
|
||||||
|
if (info.id().isEmpty()) {
|
||||||
|
qWarning("%s: got null model", Q_FUNC_INFO);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto legacyModelKey = modelSettingName(info, legacyKey);
|
||||||
|
auto newModelKey = modelSettingName(info, newKey );
|
||||||
|
bool changed = false;
|
||||||
|
if (m_settings.contains(legacyModelKey)) {
|
||||||
|
m_settings.remove(legacyModelKey);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
if (m_settings.contains(newModelKey)) {
|
||||||
|
m_settings.remove(newModelKey);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySettings::setModelChatTemplate(const ModelInfo &info, const QString &value)
|
||||||
{
|
{
|
||||||
setModelSetting("systemPrompt", info, value, force, true);
|
using namespace ModelSettingsKey;
|
||||||
|
if (setUpgradeableModelSetting(info, value, PromptTemplate, ChatTemplate))
|
||||||
|
emit chatTemplateChanged(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySettings::resetModelChatTemplate(const ModelInfo &info)
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
if (resetUpgradeableModelSetting(info, PromptTemplate, ChatTemplate))
|
||||||
|
emit chatTemplateChanged(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySettings::setModelSystemMessage(const ModelInfo &info, const QString &value)
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
if (setUpgradeableModelSetting(info, value, SystemPrompt, SystemMessage))
|
||||||
|
emit systemMessageChanged(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MySettings::resetModelSystemMessage(const ModelInfo &info)
|
||||||
|
{
|
||||||
|
using namespace ModelSettingsKey;
|
||||||
|
if (resetUpgradeableModelSetting(info, SystemPrompt, SystemMessage))
|
||||||
|
emit systemMessageChanged(info);
|
||||||
}
|
}
|
||||||
|
|
||||||
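Editor's note: roughly how the new setters and resetters behave when called from C++ or QML. This is a sketch of the semantics implemented above; the template literal is only a placeholder.

    auto *s = MySettings::globalInstance();
    // Writing migrates any legacy "promptTemplate" override to the new "chatTemplate" key
    // and emits chatTemplateChanged(info) only if something actually changed.
    s->setModelChatTemplate(info, QStringLiteral("{# placeholder template #}"));
    // Resetting removes both the legacy and the new key, falling back to the model default.
    s->resetModelChatTemplate(info);
    s->resetModelSystemMessage(info);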
void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force)
|
void MySettings::setModelChatNamePrompt(const ModelInfo &info, const QString &value, bool force)
|
||||||
@ -445,7 +612,6 @@ void MySettings::setThreadCount(int value)
|
|||||||
emit threadCountChanged();
|
emit threadCountChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MySettings::saveChatsContext() const { return getBasicSetting("saveChatsContext" ).toBool(); }
|
|
||||||
bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); }
|
bool MySettings::systemTray() const { return getBasicSetting("systemTray" ).toBool(); }
|
||||||
bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); }
|
bool MySettings::serverChat() const { return getBasicSetting("serverChat" ).toBool(); }
|
||||||
int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); }
|
int MySettings::networkPort() const { return getBasicSetting("networkPort" ).toInt(); }
|
||||||
@ -464,7 +630,6 @@ ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnu
|
|||||||
FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); }
|
FontSize MySettings::fontSize() const { return FontSize (getEnumSetting("fontSize", fontSizeNames)); }
|
||||||
SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); }
|
SuggestionMode MySettings::suggestionMode() const { return SuggestionMode(getEnumSetting("suggestionMode", suggestionModeNames)); }
|
||||||
|
|
||||||
void MySettings::setSaveChatsContext(bool value) { setBasicSetting("saveChatsContext", value); }
|
|
||||||
void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); }
|
void MySettings::setSystemTray(bool value) { setBasicSetting("systemTray", value); }
|
||||||
void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); }
|
void MySettings::setServerChat(bool value) { setBasicSetting("serverChat", value); }
|
||||||
void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); }
|
void MySettings::setNetworkPort(int value) { setBasicSetting("networkPort", value); }
|
||||||
|
@ -4,6 +4,9 @@
|
|||||||
#include "modellist.h" // IWYU pragma: keep
|
#include "modellist.h" // IWYU pragma: keep
|
||||||
|
|
||||||
#include <QDateTime>
|
#include <QDateTime>
|
||||||
|
#include <QLatin1StringView>
|
||||||
|
#include <QList>
|
||||||
|
#include <QModelIndex>
|
||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QSettings>
|
#include <QSettings>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
@ -48,7 +51,6 @@ class MySettings : public QObject
|
|||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
Q_PROPERTY(int threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||||
Q_PROPERTY(bool saveChatsContext READ saveChatsContext WRITE setSaveChatsContext NOTIFY saveChatsContextChanged)
|
|
||||||
Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged)
|
Q_PROPERTY(bool systemTray READ systemTray WRITE setSystemTray NOTIFY systemTrayChanged)
|
||||||
Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged)
|
Q_PROPERTY(bool serverChat READ serverChat WRITE setServerChat NOTIFY serverChatChanged)
|
||||||
Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged)
|
Q_PROPERTY(QString modelPath READ modelPath WRITE setModelPath NOTIFY modelPathChanged)
|
||||||
@ -75,9 +77,18 @@ class MySettings : public QObject
|
|||||||
Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
|
Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
|
||||||
Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
|
Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
|
||||||
|
|
||||||
|
private:
|
||||||
|
explicit MySettings();
|
||||||
|
~MySettings() override = default;
|
||||||
|
|
||||||
|
public Q_SLOTS:
|
||||||
|
void onModelInfoChanged(const QModelIndex &topLeft, const QModelIndex &bottomRight, const QList<int> &roles = {});
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static MySettings *globalInstance();
|
static MySettings *globalInstance();
|
||||||
|
|
||||||
|
Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
|
||||||
|
|
||||||
// Restore methods
|
// Restore methods
|
||||||
Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info);
|
Q_INVOKABLE void restoreModelDefaults(const ModelInfo &info);
|
||||||
Q_INVOKABLE void restoreApplicationDefaults();
|
Q_INVOKABLE void restoreApplicationDefaults();
|
||||||
@ -125,10 +136,14 @@ public:
|
|||||||
Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false);
|
Q_INVOKABLE void setModelRepeatPenalty(const ModelInfo &info, double value, bool force = false);
|
||||||
int modelRepeatPenaltyTokens(const ModelInfo &info) const;
|
int modelRepeatPenaltyTokens(const ModelInfo &info) const;
|
||||||
Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
|
Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
|
||||||
QString modelPromptTemplate(const ModelInfo &info) const;
|
auto modelChatTemplate(const ModelInfo &info) const -> UpgradeableSetting;
|
||||||
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
|
Q_INVOKABLE bool isModelChatTemplateSet(const ModelInfo &info) const;
|
||||||
QString modelSystemPrompt(const ModelInfo &info) const;
|
Q_INVOKABLE void setModelChatTemplate(const ModelInfo &info, const QString &value);
|
||||||
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false);
|
Q_INVOKABLE void resetModelChatTemplate(const ModelInfo &info);
|
||||||
|
auto modelSystemMessage(const ModelInfo &info) const -> UpgradeableSetting;
|
||||||
|
Q_INVOKABLE bool isModelSystemMessageSet(const ModelInfo &info) const;
|
||||||
|
Q_INVOKABLE void setModelSystemMessage(const ModelInfo &info, const QString &value);
|
||||||
|
Q_INVOKABLE void resetModelSystemMessage(const ModelInfo &info);
|
||||||
int modelContextLength(const ModelInfo &info) const;
|
int modelContextLength(const ModelInfo &info) const;
|
||||||
Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false);
|
Q_INVOKABLE void setModelContextLength(const ModelInfo &info, int value, bool force = false);
|
||||||
int modelGpuLayers(const ModelInfo &info) const;
|
int modelGpuLayers(const ModelInfo &info) const;
|
||||||
@ -141,8 +156,6 @@ public:
|
|||||||
// Application settings
|
// Application settings
|
||||||
int threadCount() const;
|
int threadCount() const;
|
||||||
void setThreadCount(int value);
|
void setThreadCount(int value);
|
||||||
bool saveChatsContext() const;
|
|
||||||
void setSaveChatsContext(bool value);
|
|
||||||
bool systemTray() const;
|
bool systemTray() const;
|
||||||
void setSystemTray(bool value);
|
void setSystemTray(bool value);
|
||||||
bool serverChat() const;
|
bool serverChat() const;
|
||||||
@ -215,12 +228,11 @@ Q_SIGNALS:
|
|||||||
void gpuLayersChanged(const ModelInfo &info);
|
void gpuLayersChanged(const ModelInfo &info);
|
||||||
void repeatPenaltyChanged(const ModelInfo &info);
|
void repeatPenaltyChanged(const ModelInfo &info);
|
||||||
void repeatPenaltyTokensChanged(const ModelInfo &info);
|
void repeatPenaltyTokensChanged(const ModelInfo &info);
|
||||||
void promptTemplateChanged(const ModelInfo &info);
|
void chatTemplateChanged(const ModelInfo &info, bool fromInfo = false);
|
||||||
void systemPromptChanged(const ModelInfo &info);
|
void systemMessageChanged(const ModelInfo &info, bool fromInfo = false);
|
||||||
void chatNamePromptChanged(const ModelInfo &info);
|
void chatNamePromptChanged(const ModelInfo &info);
|
||||||
void suggestedFollowUpPromptChanged(const ModelInfo &info);
|
void suggestedFollowUpPromptChanged(const ModelInfo &info);
|
||||||
void threadCountChanged();
|
void threadCountChanged();
|
||||||
void saveChatsContextChanged();
|
|
||||||
void systemTrayChanged();
|
void systemTrayChanged();
|
||||||
void serverChatChanged();
|
void serverChatChanged();
|
||||||
void modelPathChanged();
|
void modelPathChanged();
|
||||||
@ -245,6 +257,30 @@ Q_SIGNALS:
|
|||||||
void suggestionModeChanged();
|
void suggestionModeChanged();
|
||||||
void languageAndLocaleChanged();
|
void languageAndLocaleChanged();
|
||||||
|
|
||||||
|
private:
|
||||||
|
QVariant getBasicSetting(const QString &name) const;
|
||||||
|
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
|
||||||
|
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
|
||||||
|
QVariant getModelSetting(QLatin1StringView name, const ModelInfo &info) const;
|
||||||
|
QVariant getModelSetting(const char *name, const ModelInfo &info) const;
|
||||||
|
void setModelSetting(QLatin1StringView name, const ModelInfo &info, const QVariant &value, bool force,
|
||||||
|
bool signal = false);
|
||||||
|
void setModelSetting(const char *name, const ModelInfo &info, const QVariant &value, bool force,
|
||||||
|
bool signal = false);
|
||||||
|
auto getUpgradeableModelSetting(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
) const -> UpgradeableSetting;
|
||||||
|
bool isUpgradeableModelSettingSet(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
) const;
|
||||||
|
bool setUpgradeableModelSetting(
|
||||||
|
const ModelInfo &info, const QString &value, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
);
|
||||||
|
bool resetUpgradeableModelSetting(
|
||||||
|
const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
|
||||||
|
);
|
||||||
|
QString filePathForLocale(const QLocale &locale);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
QSettings m_settings;
|
QSettings m_settings;
|
||||||
bool m_forceMetal;
|
bool m_forceMetal;
|
||||||
@ -253,18 +289,7 @@ private:
|
|||||||
const QStringList m_uiLanguages;
|
const QStringList m_uiLanguages;
|
||||||
std::unique_ptr<QTranslator> m_translator;
|
std::unique_ptr<QTranslator> m_translator;
|
||||||
|
|
||||||
private:
|
|
||||||
explicit MySettings();
|
|
||||||
~MySettings() {}
|
|
||||||
friend class MyPrivateSettings;
|
friend class MyPrivateSettings;
|
||||||
|
|
||||||
QVariant getBasicSetting(const QString &name) const;
|
|
||||||
void setBasicSetting(const QString &name, const QVariant &value, std::optional<QString> signal = std::nullopt);
|
|
||||||
int getEnumSetting(const QString &setting, const QStringList &valueNames) const;
|
|
||||||
QVariant getModelSetting(const QString &name, const ModelInfo &info) const;
|
|
||||||
void setModelSetting(const QString &name, const ModelInfo &info, const QVariant &value, bool force,
|
|
||||||
bool signal = false);
|
|
||||||
QString filePathForLocale(const QLocale &locale);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MYSETTINGS_H
|
#endif // MYSETTINGS_H
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "localdocsmodel.h"
|
#include "localdocsmodel.h"
|
||||||
#include "modellist.h"
|
#include "modellist.h"
|
||||||
#include "mysettings.h"
|
#include "mysettings.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
#include <gpt4all-backend/llmodel.h>
|
#include <gpt4all-backend/llmodel.h>
|
||||||
|
|
||||||
@ -192,11 +193,14 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto *currentChat = ChatListModel::globalInstance()->currentChat();
|
||||||
|
Q_ASSERT(currentChat);
|
||||||
|
auto modelInfo = currentChat->modelInfo();
|
||||||
|
|
||||||
Q_ASSERT(doc.isObject());
|
Q_ASSERT(doc.isObject());
|
||||||
Q_ASSERT(ChatListModel::globalInstance()->currentChat());
|
|
||||||
QJsonObject object = doc.object();
|
QJsonObject object = doc.object();
|
||||||
object.insert("source", "gpt4all-chat");
|
object.insert("source", "gpt4all-chat");
|
||||||
object.insert("agent_id", ChatListModel::globalInstance()->currentChat()->modelInfo().filename());
|
object.insert("agent_id", modelInfo.filename());
|
||||||
object.insert("submitter_id", m_uniqueId);
|
object.insert("submitter_id", m_uniqueId);
|
||||||
object.insert("ingest_id", ingestId);
|
object.insert("ingest_id", ingestId);
|
||||||
|
|
||||||
@ -204,8 +208,9 @@ bool Network::packageAndSendJson(const QString &ingestId, const QString &json)
|
|||||||
if (!attribution.isEmpty())
|
if (!attribution.isEmpty())
|
||||||
object.insert("network/attribution", attribution);
|
object.insert("network/attribution", attribution);
|
||||||
|
|
||||||
QString promptTemplate = ChatListModel::globalInstance()->currentChat()->modelInfo().promptTemplate();
|
if (!modelInfo.id().isNull())
|
||||||
object.insert("prompt_template", promptTemplate);
|
if (auto tmpl = modelInfo.chatTemplate().asModern())
|
||||||
|
object.insert("chat_template"_L1, *tmpl);
|
||||||
|
|
||||||
QJsonDocument newDoc;
|
QJsonDocument newDoc;
|
||||||
newDoc.setObject(object);
|
newDoc.setObject(object);
|
||||||
@ -358,7 +363,8 @@ void Network::sendStartup()
|
|||||||
|
|
||||||
void Network::trackChatEvent(const QString &ev, QVariantMap props)
|
void Network::trackChatEvent(const QString &ev, QVariantMap props)
|
||||||
{
|
{
|
||||||
const auto &curChat = ChatListModel::globalInstance()->currentChat();
|
auto *curChat = ChatListModel::globalInstance()->currentChat();
|
||||||
|
Q_ASSERT(curChat);
|
||||||
if (!props.contains("model"))
|
if (!props.contains("model"))
|
||||||
props.insert("model", curChat->modelInfo().filename());
|
props.insert("model", curChat->modelInfo().filename());
|
||||||
props.insert("device_backend", curChat->deviceBackend());
|
props.insert("device_backend", curChat->deviceBackend());
|
||||||
@ -366,7 +372,7 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
|
|||||||
props.insert("doc_collections_enabled", curChat->collectionList().count());
|
props.insert("doc_collections_enabled", curChat->collectionList().count());
|
||||||
props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
|
props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
|
||||||
props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
|
props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
|
||||||
props.insert("using_server", ChatListModel::globalInstance()->currentChat()->isServer());
|
props.insert("using_server", curChat->isServer());
|
||||||
trackEvent(ev, props);
|
trackEvent(ev, props);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,11 +313,8 @@ const std::unordered_map<BaseCompletionRequest::Type, const char *> BaseCompleti
|
|||||||
class ChatRequest : public BaseCompletionRequest {
|
class ChatRequest : public BaseCompletionRequest {
|
||||||
public:
|
public:
|
||||||
struct Message {
|
struct Message {
|
||||||
enum class Role : uint8_t {
|
enum class Role { System, User, Assistant };
|
||||||
User,
|
Role role;
|
||||||
Assistant,
|
|
||||||
};
|
|
||||||
Role role;
|
|
||||||
QString content;
|
QString content;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -349,7 +346,6 @@ protected:
|
|||||||
this->messages.clear();
|
this->messages.clear();
|
||||||
{
|
{
|
||||||
QCborArray arr = value.toArray();
|
QCborArray arr = value.toArray();
|
||||||
Message::Role nextRole = Message::Role::User;
|
|
||||||
for (qsizetype i = 0; i < arr.size(); i++) {
|
for (qsizetype i = 0; i < arr.size(); i++) {
|
||||||
const auto &elem = arr[i];
|
const auto &elem = arr[i];
|
||||||
if (!elem.isMap())
|
if (!elem.isMap())
|
||||||
@ -360,9 +356,9 @@ protected:
|
|||||||
QCborMap msg = elem.toMap();
|
QCborMap msg = elem.toMap();
|
||||||
Message res;
|
Message res;
|
||||||
QString role = takeValue(msg, "role", String, /*required*/ true).toString();
|
QString role = takeValue(msg, "role", String, /*required*/ true).toString();
|
||||||
if (role == u"system"_s)
|
if (role == u"system"_s) {
|
||||||
continue; // FIXME(jared): don't ignore these
|
res.role = Message::Role::System;
|
||||||
if (role == u"user"_s) {
|
} else if (role == u"user"_s) {
|
||||||
res.role = Message::Role::User;
|
res.role = Message::Role::User;
|
||||||
} else if (role == u"assistant"_s) {
|
} else if (role == u"assistant"_s) {
|
||||||
res.role = Message::Role::Assistant;
|
res.role = Message::Role::Assistant;
|
||||||
@ -374,13 +370,7 @@ protected:
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
res.content = takeValue(msg, "content", String, /*required*/ true).toString();
|
res.content = takeValue(msg, "content", String, /*required*/ true).toString();
|
||||||
if (res.role != nextRole)
|
|
||||||
throw InvalidRequestError(fmt::format(
|
|
||||||
"Invalid 'messages[{}].role': did not expect '{}' here", i, role
|
|
||||||
));
|
|
||||||
this->messages.append(res);
|
this->messages.append(res);
|
||||||
nextRole = res.role == Message::Role::User ? Message::Role::Assistant
|
|
||||||
: Message::Role::User;
|
|
||||||
|
|
||||||
if (!msg.isEmpty())
|
if (!msg.isEmpty())
|
||||||
throw InvalidRequestError(fmt::format(
|
throw InvalidRequestError(fmt::format(
|
||||||
@ -630,8 +620,7 @@ void Server::start()
|
|||||||
});
|
});
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
connect(this, &Server::requestServerNewPromptResponsePair, m_chat,
|
connect(this, &Server::requestResetResponseState, m_chat, &Chat::resetResponseState, Qt::BlockingQueuedConnection);
|
||||||
&Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
||||||
@ -642,6 +631,10 @@ static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::opt
|
|||||||
auto Server::handleCompletionRequest(const CompletionRequest &request)
|
auto Server::handleCompletionRequest(const CompletionRequest &request)
|
||||||
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
||||||
{
|
{
|
||||||
|
Q_ASSERT(m_chatModel);
|
||||||
|
|
||||||
|
auto *mySettings = MySettings::globalInstance();
|
||||||
|
|
||||||
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
|
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
|
||||||
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
|
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
|
||||||
for (const ModelInfo &info : modelList) {
|
for (const ModelInfo &info : modelList) {
|
||||||
@ -654,10 +647,6 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// adds prompt/response items to GUI
|
|
||||||
emit requestServerNewPromptResponsePair(request.prompt); // blocks
|
|
||||||
resetResponse();
|
|
||||||
|
|
||||||
// load the new model if necessary
|
// load the new model if necessary
|
||||||
setShouldBeLoaded(true);
|
setShouldBeLoaded(true);
|
||||||
|
|
||||||
@ -666,47 +655,55 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
|
|||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
emit requestResetResponseState(); // blocks
|
||||||
|
qsizetype prevMsgIndex = m_chatModel->count() - 1;
|
||||||
|
if (prevMsgIndex >= 0)
|
||||||
|
m_chatModel->updateCurrentResponse(prevMsgIndex, false);
|
||||||
|
|
||||||
// NB: this resets the context, regardless of whether this model is already loaded
|
// NB: this resets the context, regardless of whether this model is already loaded
|
||||||
if (!loadModel(modelInfo)) {
|
if (!loadModel(modelInfo)) {
|
||||||
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
|
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
|
||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
|
// add prompt/response items to GUI
|
||||||
const int top_k = modelInfo.topK();
|
m_chatModel->appendPrompt(request.prompt);
|
||||||
const int n_batch = modelInfo.promptBatchSize();
|
m_chatModel->appendResponse(prevMsgIndex + 1);
|
||||||
const auto repeat_penalty = float(modelInfo.repeatPenalty());
|
|
||||||
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
|
|
||||||
|
|
||||||
|
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
|
||||||
|
LLModel::PromptContext promptCtx {
|
||||||
|
.n_predict = request.max_tokens,
|
||||||
|
.top_k = mySettings->modelTopK(modelInfo),
|
||||||
|
.top_p = request.top_p,
|
||||||
|
.min_p = request.min_p,
|
||||||
|
.temp = request.temperature,
|
||||||
|
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
|
||||||
|
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
|
||||||
|
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto promptUtf8 = request.prompt.toUtf8();
|
||||||
int promptTokens = 0;
|
int promptTokens = 0;
|
||||||
int responseTokens = 0;
|
int responseTokens = 0;
|
||||||
QList<QPair<QString, QList<ResultInfo>>> responses;
|
QStringList responses;
|
||||||
for (int i = 0; i < request.n; ++i) {
|
for (int i = 0; i < request.n; ++i) {
|
||||||
if (!promptInternal(
|
PromptResult result;
|
||||||
m_collections,
|
try {
|
||||||
request.prompt,
|
result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
|
||||||
/*promptTemplate*/ u"%1"_s,
|
promptCtx,
|
||||||
request.max_tokens,
|
/*usedLocalDocs*/ false);
|
||||||
top_k,
|
} catch (const std::exception &e) {
|
||||||
request.top_p,
|
emit responseChanged(e.what());
|
||||||
request.min_p,
|
emit responseStopped(0);
|
||||||
request.temperature,
|
|
||||||
n_batch,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n)) {
|
|
||||||
|
|
||||||
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
|
|
||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
QString resp = response(/*trim*/ false);
|
QString resp = QString::fromUtf8(result.response);
|
||||||
if (request.echo)
|
if (request.echo)
|
||||||
resp = request.prompt + resp;
|
resp = request.prompt + resp;
|
||||||
responses.append({resp, m_databaseResults});
|
responses << resp;
|
||||||
if (!promptTokens)
|
if (i == 0)
|
||||||
promptTokens = m_promptTokens;
|
promptTokens = result.promptTokens;
|
||||||
responseTokens += m_promptResponseTokens - m_promptTokens;
|
responseTokens += result.responseTokens;
|
||||||
if (i < request.n - 1)
|
|
||||||
resetResponse();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QJsonObject responseObject {
|
QJsonObject responseObject {
|
||||||
@ -717,25 +714,13 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
|
|||||||
};
|
};
|
||||||
|
|
||||||
QJsonArray choices;
|
QJsonArray choices;
|
||||||
{
|
for (qsizetype i = 0; auto &resp : std::as_const(responses)) {
|
||||||
int index = 0;
|
choices << QJsonObject {
|
||||||
for (const auto &r : responses) {
|
{ "text", resp },
|
||||||
QString result = r.first;
|
{ "index", i++ },
|
||||||
QList<ResultInfo> infos = r.second;
|
{ "logprobs", QJsonValue::Null },
|
||||||
QJsonObject choice {
|
{ "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
|
||||||
{ "text", result },
|
};
|
||||||
{ "index", index++ },
|
|
||||||
{ "logprobs", QJsonValue::Null },
|
|
||||||
{ "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
|
|
||||||
};
|
|
||||||
if (MySettings::globalInstance()->localDocsShowReferences()) {
|
|
||||||
QJsonArray references;
|
|
||||||
for (const auto &ref : infos)
|
|
||||||
references.append(resultToJson(ref));
|
|
||||||
choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
|
|
||||||
}
|
|
||||||
choices.append(choice);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responseObject.insert("choices", choices);
|
responseObject.insert("choices", choices);
|
||||||
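Editor's note: for reference, the shape of the PromptContext aggregate the server now fills in. Field names follow the backend header used by this commit; the literal values are arbitrary examples, not defaults.

    // Sampling parameters travel in one designated-initializer aggregate instead of a long
    // positional promptInternal() argument list; unspecified fields keep their defaults.
    LLModel::PromptContext ctx {
        .n_predict      = 128,
        .top_k          = 40,
        .top_p          = 0.9f,
        .min_p          = 0.0f,
        .temp           = 0.7f,
        .n_batch        = 9,
        .repeat_penalty = 1.10f,
        .repeat_last_n  = 64,
    };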
@ -751,6 +736,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
|
|||||||
auto Server::handleChatRequest(const ChatRequest &request)
|
auto Server::handleChatRequest(const ChatRequest &request)
|
||||||
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
-> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
|
||||||
{
|
{
|
||||||
|
auto *mySettings = MySettings::globalInstance();
|
||||||
|
|
||||||
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
|
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
|
||||||
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
|
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
|
||||||
for (const ModelInfo &info : modelList) {
|
for (const ModelInfo &info : modelList) {
|
||||||
@ -771,83 +758,58 @@ auto Server::handleChatRequest(const ChatRequest &request)
|
|||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
emit requestResetResponseState(); // blocks
|
||||||
|
|
||||||
// NB: this resets the context, regardless of whether this model is already loaded
|
// NB: this resets the context, regardless of whether this model is already loaded
|
||||||
if (!loadModel(modelInfo)) {
|
if (!loadModel(modelInfo)) {
|
||||||
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
|
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
|
||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
|
|
||||||
const QString promptTemplate = modelInfo.promptTemplate();
|
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
|
||||||
const int top_k = modelInfo.topK();
|
|
||||||
const int n_batch = modelInfo.promptBatchSize();
|
|
||||||
const auto repeat_penalty = float(modelInfo.repeatPenalty());
|
|
||||||
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
|
|
||||||
|
|
||||||
int promptTokens = 0;
|
Q_ASSERT(!request.messages.isEmpty());
|
||||||
|
|
||||||
|
// adds prompt/response items to GUI
|
||||||
|
std::vector<ChatItem> chatItems;
|
||||||
|
for (auto &message : request.messages) {
|
||||||
|
using enum ChatRequest::Message::Role;
|
||||||
|
switch (message.role) {
|
||||||
|
case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break;
|
||||||
|
case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break;
|
||||||
|
case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m_chatModel->appendResponseWithHistory(chatItems);
|
||||||
|
|
||||||
|
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
|
||||||
|
LLModel::PromptContext promptCtx {
|
||||||
|
.n_predict = request.max_tokens,
|
||||||
|
.top_k = mySettings->modelTopK(modelInfo),
|
||||||
|
.top_p = request.top_p,
|
||||||
|
.min_p = request.min_p,
|
||||||
|
.temp = request.temperature,
|
||||||
|
.n_batch = mySettings->modelPromptBatchSize(modelInfo),
|
||||||
|
.repeat_penalty = float(mySettings->modelRepeatPenalty(modelInfo)),
|
||||||
|
.repeat_last_n = mySettings->modelRepeatPenaltyTokens(modelInfo),
|
||||||
|
};
|
||||||
|
|
||||||
|
int promptTokens = 0;
|
||||||
int responseTokens = 0;
|
int responseTokens = 0;
|
||||||
QList<QPair<QString, QList<ResultInfo>>> responses;
|
QList<QPair<QString, QList<ResultInfo>>> responses;
|
||||||
Q_ASSERT(!request.messages.isEmpty());
|
|
||||||
Q_ASSERT(request.messages.size() % 2 == 1);
|
|
||||||
for (int i = 0; i < request.messages.size() - 2; i += 2) {
|
|
||||||
using enum ChatRequest::Message::Role;
|
|
||||||
auto &user = request.messages[i];
|
|
||||||
auto &assistant = request.messages[i + 1];
|
|
||||||
Q_ASSERT(user.role == User);
|
|
||||||
Q_ASSERT(assistant.role == Assistant);
|
|
||||||
|
|
||||||
// adds prompt/response items to GUI
|
|
||||||
emit requestServerNewPromptResponsePair(user.content); // blocks
|
|
||||||
resetResponse();
|
|
||||||
|
|
||||||
if (!promptInternal(
|
|
||||||
{},
|
|
||||||
user.content,
|
|
||||||
promptTemplate,
|
|
||||||
request.max_tokens,
|
|
||||||
top_k,
|
|
||||||
request.top_p,
|
|
||||||
request.min_p,
|
|
||||||
request.temperature,
|
|
||||||
n_batch,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
assistant.content)
|
|
||||||
) {
|
|
||||||
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
|
|
||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
|
||||||
}
|
|
||||||
promptTokens += m_promptResponseTokens; // previous responses are part of current prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
QString lastMessage = request.messages.last().content;
|
|
||||||
// adds prompt/response items to GUI
|
|
||||||
emit requestServerNewPromptResponsePair(lastMessage); // blocks
|
|
||||||
resetResponse();
|
|
||||||
|
|
||||||
for (int i = 0; i < request.n; ++i) {
|
for (int i = 0; i < request.n; ++i) {
|
||||||
if (!promptInternal(
|
ChatPromptResult result;
|
||||||
m_collections,
|
try {
|
||||||
lastMessage,
|
result = promptInternalChat(m_collections, promptCtx);
|
||||||
promptTemplate,
|
} catch (const std::exception &e) {
|
||||||
request.max_tokens,
|
emit responseChanged(e.what());
|
||||||
top_k,
|
emit responseStopped(0);
|
||||||
request.top_p,
|
|
||||||
request.min_p,
|
|
||||||
request.temperature,
|
|
||||||
n_batch,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n)
|
|
||||||
) {
|
|
||||||
std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
|
|
||||||
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
return makeError(QHttpServerResponder::StatusCode::InternalServerError);
|
||||||
}
|
}
|
||||||
responses.append({response(), m_databaseResults});
|
responses.emplace_back(result.response, result.databaseResults);
|
||||||
// FIXME(jared): these are UI counts and do not include framing tokens, which they should
|
|
||||||
if (i == 0)
|
if (i == 0)
|
||||||
promptTokens += m_promptTokens;
|
promptTokens = result.promptTokens;
|
||||||
responseTokens += m_promptResponseTokens - m_promptTokens;
|
responseTokens += result.responseTokens;
|
||||||
if (i != request.n - 1)
|
|
||||||
resetResponse();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QJsonObject responseObject {
|
QJsonObject responseObject {
|
||||||
|
@ -33,7 +33,7 @@ public Q_SLOTS:
|
|||||||
void start();
|
void start();
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void requestServerNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
|
void requestResetResponseState();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
|
auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
|
||||||
|
@ -3,23 +3,41 @@
|
|||||||
#include <fmt/base.h>
|
#include <fmt/base.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
|
#include <QByteArray>
|
||||||
|
#include <QJsonValue>
|
||||||
|
#include <QLatin1StringView>
|
||||||
#include <QString>
|
#include <QString>
|
||||||
|
#include <QStringView>
|
||||||
|
#include <QUtf8StringView>
|
||||||
#include <QVariant>
|
#include <QVariant>
|
||||||
|
|
||||||
#include <string>
|
#include <initializer_list>
|
||||||
|
#include <string_view>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
class QJsonObject;
|
||||||
|
|
||||||
|
|
||||||
// fmtlib formatters for QString and QVariant
|
// fmtlib formatters for QString and QVariant
|
||||||
|
|
||||||
#define MAKE_FORMATTER(type, conversion) \
|
#define MAKE_FORMATTER(type, conversion) \
|
||||||
template <> \
|
template <> \
|
||||||
struct fmt::formatter<type, char>: fmt::formatter<std::string, char> { \
|
struct fmt::formatter<type, char>: fmt::formatter<std::string_view, char> { \
|
||||||
template <typename FmtContext> \
|
template <typename FmtContext> \
|
||||||
FmtContext::iterator format(const type &value, FmtContext &ctx) const \
|
FmtContext::iterator format(const type &value, FmtContext &ctx) const \
|
||||||
{ \
|
{ \
|
||||||
return formatter<std::string, char>::format(conversion, ctx); \
|
auto valueUtf8 = (conversion); \
|
||||||
} \
|
std::string_view view(valueUtf8.cbegin(), valueUtf8.cend()); \
|
||||||
|
return formatter<std::string_view, char>::format(view, ctx); \
|
||||||
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
MAKE_FORMATTER(QString, value.toStdString() );
|
MAKE_FORMATTER(QUtf8StringView, value );
|
||||||
MAKE_FORMATTER(QVariant, value.toString().toStdString());
|
MAKE_FORMATTER(QStringView, value.toUtf8() );
|
||||||
|
MAKE_FORMATTER(QString, value.toUtf8() );
|
||||||
|
MAKE_FORMATTER(QVariant, value.toString().toUtf8());
|
||||||
|
|
||||||
|
// alternative to QJsonObject's initializer_list constructor that accepts Latin-1 strings
|
||||||
|
QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args);
|
||||||
|
|
||||||
|
#include "utils.inl"
|
||||||
|
9
gpt4all-chat/src/utils.inl
Normal file
9
gpt4all-chat/src/utils.inl
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#include <QJsonObject>
|
||||||
|
|
||||||
|
inline QJsonObject makeJsonObject(std::initializer_list<std::pair<QLatin1StringView, QJsonValue>> args)
|
||||||
|
{
|
||||||
|
QJsonObject obj;
|
||||||
|
for (auto &arg : args)
|
||||||
|
obj.insert(arg.first, arg.second);
|
||||||
|
return obj;
|
||||||
|
}
|
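Editor's note: a small usage sketch for the two helpers introduced here. It assumes Qt::Literals::StringLiterals is in scope for the _L1 suffix, as it is elsewhere in this commit; the key/value strings are placeholders.

    // QString, QStringView, QUtf8StringView and QVariant now format via their UTF-8 bytes.
    QString who = QStringLiteral("Jinja2Cpp");
    std::string line = fmt::format("rendering with {}", who);

    // makeJsonObject() builds a QJsonObject from Latin-1 keys without QStringLiteral noise.
    QJsonObject obj = makeJsonObject({
        { "chat_template"_L1, QStringLiteral("{# placeholder #}") },
        { "source"_L1,        QStringLiteral("gpt4all-chat")      },
    });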
@ -203,11 +203,10 @@ EXPECTED_MODEL_INFO = {
|
|||||||
EXPECTED_COMPLETIONS_RESPONSE = {
|
EXPECTED_COMPLETIONS_RESPONSE = {
|
||||||
'choices': [
|
'choices': [
|
||||||
{
|
{
|
||||||
'finish_reason': 'stop',
|
'finish_reason': 'length',
|
||||||
'index': 0,
|
'index': 0,
|
||||||
'logprobs': None,
|
'logprobs': None,
|
||||||
'references': None,
|
'text': ' jumps over the lazy dog.\n',
|
||||||
'text': ' jumps over the lazy dog.',
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
'id': 'placeholder',
|
'id': 'placeholder',
|
||||||
@ -242,18 +241,14 @@ def test_with_models(chat_server_with_model: None) -> None:
|
|||||||
'type': 'invalid_request_error',
|
'type': 'invalid_request_error',
|
||||||
}}
|
}}
|
||||||
|
|
||||||
data = {
|
data = dict(
|
||||||
'model': 'Llama 3.2 1B Instruct',
|
model = 'Llama 3.2 1B Instruct',
|
||||||
'prompt': 'The quick brown fox',
|
prompt = 'The quick brown fox',
|
||||||
'temperature': 0,
|
temperature = 0,
|
||||||
}
|
max_tokens = 6,
|
||||||
|
)
|
||||||
response = request.post('completions', data=data)
|
response = request.post('completions', data=data)
|
||||||
assert len(response['choices']) == 1
|
del response['created'] # Remove the dynamic field for comparison
|
||||||
assert response['choices'][0].keys() == {'text', 'index', 'logprobs', 'references', 'finish_reason'}
|
|
||||||
assert response['choices'][0]['text'] == ' jumps over the lazy dog.'
|
|
||||||
assert 'created' in response
|
|
||||||
response.pop('created') # Remove the dynamic field for comparison
|
|
||||||
assert response == EXPECTED_COMPLETIONS_RESPONSE
|
assert response == EXPECTED_COMPLETIONS_RESPONSE
|
||||||
|
|
||||||
|
|
||||||
|