Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-13 22:39:11 +00:00)
fix chat-style prompt templates (#1970)
Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
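For reference, a chat-style prompt template keeps the user text behind a "%1" placeholder and is expanded with QString::arg(), which is what ChatGPT::prompt() now does via QString::fromStdString(promptTemplate).arg(...). Below is a minimal stand-alone sketch of that expansion; the example prompt string and the small main() are illustrative only and not part of this commit, and it assumes a working QtCore setup.

// Sketch: expanding a chat-style prompt template with a "%1" placeholder,
// mirroring the QString::arg() call added to ChatGPT::prompt().
// Build assumption: any Qt 5/6 project with QtCore available.
#include <QString>
#include <QTextStream>

int main() {
    // Template taken from the updated models JSON in this commit.
    const QString promptTemplate =
        QStringLiteral("<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n");

    // Hypothetical user input, for illustration only.
    const QString userPrompt = QStringLiteral("Summarize the change in this commit.");

    // "%1" is replaced by the raw prompt; the ChatML control tokens come
    // entirely from the per-model template, not from hard-coded strings.
    const QString expanded = promptTemplate.arg(userPrompt);

    QTextStream out(stdout);
    out << expanded << '\n';
    return 0;
}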
@@ -75,13 +75,18 @@ size_t ChatGPT::restoreState(const uint8_t *src)
 }
 
 void ChatGPT::prompt(const std::string &prompt,
+        const std::string &promptTemplate,
         std::function<bool(int32_t)> promptCallback,
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
-        PromptContext &promptCtx) {
+        PromptContext &promptCtx,
+        bool special,
+        std::string *fakeReply) {
 
     Q_UNUSED(promptCallback);
     Q_UNUSED(recalculateCallback);
+    Q_UNUSED(special);
+    Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT
 
     if (!isModelLoaded()) {
         std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
@@ -109,7 +114,7 @@ void ChatGPT::prompt(const std::string &prompt,
 
     QJsonObject promptObject;
     promptObject.insert("role", "user");
-    promptObject.insert("content", QString::fromStdString(prompt));
+    promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
     messages.append(promptObject);
     root.insert("messages", messages);
 
@@ -1,6 +1,8 @@
 #ifndef CHATGPT_H
 #define CHATGPT_H
 
+#include <stdexcept>
+
 #include <QObject>
 #include <QNetworkReply>
 #include <QNetworkRequest>
@@ -55,10 +57,13 @@ public:
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
     void prompt(const std::string &prompt,
+                const std::string &promptTemplate,
                 std::function<bool(int32_t)> promptCallback,
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
-                PromptContext &ctx) override;
+                PromptContext &ctx,
+                bool special,
+                std::string *fakeReply) override;
 
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
@@ -69,7 +74,7 @@ public:
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
 
-    bool callResponse(int32_t token, const std::string& string);
+    bool callResponse(int32_t token, const std::string &string);
 
 Q_SIGNALS:
     void request(const QString &apiKey,
@@ -80,12 +85,41 @@ protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
-    std::string tokenToString(Token) const override { return std::string(); }
-    Token sampleToken(PromptContext &ctx) const override { return -1; }
-    bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
-    int32_t contextLength() const override { return -1; }
-    const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
+
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
+        (void)ctx;
+        (void)str;
+        (void)special;
+        throw std::logic_error("not implemented");
+    }
+
+    std::string tokenToString(Token id) const override {
+        (void)id;
+        throw std::logic_error("not implemented");
+    }
+
+    Token sampleToken(PromptContext &ctx) const override {
+        (void)ctx;
+        throw std::logic_error("not implemented");
+    }
+
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
+        (void)ctx;
+        (void)tokens;
+        throw std::logic_error("not implemented");
+    }
+
+    int32_t contextLength() const override {
+        throw std::logic_error("not implemented");
+    }
+
+    const std::vector<Token> &endTokens() const override {
+        throw std::logic_error("not implemented");
+    }
+
+    bool shouldAddBOS() const override {
+        throw std::logic_error("not implemented");
+    }
 
 private:
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
@@ -303,6 +303,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
 
                 if (m_llModelInfo.model) {
+                    if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
+                        // TODO(cebtenzzre): warn that this model is out-of-date
+                    }
 
                     m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
                         emit modelLoadingPercentageChanged(progress);
@@ -588,14 +591,11 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     }
 
     // Augment the prompt template with the results if any
-    QList<QString> augmentedTemplate;
+    QList<QString> docsContext;
     if (!databaseResults.isEmpty())
-        augmentedTemplate.append("### Context:");
+        docsContext.append("### Context:");
     for (const ResultInfo &info : databaseResults)
-        augmentedTemplate.append(info.text);
-    augmentedTemplate.append(promptTemplate);
-
-    QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
+        docsContext.append(info.text);
 
     int n_threads = MySettings::globalInstance()->threadCount();
 
@@ -605,7 +605,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
                                    std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
     emit promptProcessing();
-    qint32 logitsBefore = m_ctx.logits.size();
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
     m_ctx.top_p = top_p;
@@ -615,11 +614,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_ctx.repeat_last_n = repeat_penalty_tokens;
     m_llModelInfo.model->setThreadCount(n_threads);
 #if defined(DEBUG)
-    printf("%s", qPrintable(instructPrompt));
+    printf("%s", qPrintable(prompt));
     fflush(stdout);
 #endif
     m_timer->start();
-    m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+    if (!docsContext.isEmpty()) {
+        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
+        m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
+        m_ctx.n_predict = old_n_predict; // now we are ready for a response
+    }
+    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
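The hunk above decodes the localdocs context in its own prompt() call with n_predict temporarily zeroed via std::exchange, so the context is ingested without generating a response, and only then runs the templated user prompt. A small stand-alone sketch of that save/zero/restore pattern follows; the PromptContext struct here is a stripped-down stand-in for illustration, not the real LLModel type.

// Sketch of the std::exchange() save/zero/restore pattern used above.
// PromptContext is a minimal stand-in, not the type from llmodel.h.
#include <cstdint>
#include <iostream>
#include <utility>

struct PromptContext {
    int32_t n_predict = 128; // assumed default, for the sketch only
};

int main() {
    PromptContext ctx;

    // Decode extra context (e.g. localdocs results) without a response:
    auto old_n_predict = std::exchange(ctx.n_predict, 0);
    std::cout << "while decoding context: n_predict = " << ctx.n_predict << '\n';

    // Restore before the real user prompt so the model can answer normally.
    ctx.n_predict = old_n_predict;
    std::cout << "before the user prompt:  n_predict = " << ctx.n_predict << '\n';
    return 0;
}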
@@ -720,7 +724,7 @@ void ChatLLM::generateName()
     printf("%s", qPrintable(instructPrompt));
     fflush(stdout);
 #endif
-    m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+    m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -780,16 +784,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
-{
-#if defined(DEBUG)
-    qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
-#endif
-    Q_UNUSED(token);
-    Q_UNUSED(response);
-    return false;
-}
-
 bool ChatLLM::handleSystemRecalculate(bool isRecalc)
 {
 #if defined(DEBUG)
@@ -808,16 +802,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
-{
-#if defined(DEBUG)
-    qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
-#endif
-    Q_UNUSED(token);
-    Q_UNUSED(response);
-    return false;
-}
-
 bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
 {
 #if defined(DEBUG)
@@ -1027,8 +1011,6 @@ void ChatLLM::processSystemPrompt()
     m_ctx = LLModel::PromptContext();
 
     auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
-    auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
-        std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
 
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1051,7 +1033,9 @@ void ChatLLM::processSystemPrompt()
     printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
     fflush(stdout);
 #endif
-    m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
+    auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
+    m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
+    m_ctx.n_predict = old_n_predict;
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -1073,8 +1057,6 @@ void ChatLLM::processRestoreStateFromText()
     m_ctx = LLModel::PromptContext();
 
     auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
-    auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
-        std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
 
     const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
@@ -1094,9 +1076,19 @@ void ChatLLM::processRestoreStateFromText()
     m_ctx.repeat_penalty = repeat_penalty;
     m_ctx.repeat_last_n = repeat_penalty_tokens;
     m_llModelInfo.model->setThreadCount(n_threads);
-    for (auto pair : m_stateFromText) {
-        const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
-        m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+
+    auto it = m_stateFromText.begin();
+    while (it < m_stateFromText.end()) {
+        auto &prompt = *it++;
+        Q_ASSERT(prompt.first == "Prompt: ");
+        Q_ASSERT(it < m_stateFromText.end());
+
+        auto &response = *it++;
+        Q_ASSERT(response.first != "Prompt: ");
+        auto responseText = response.second.toStdString();
+
+        m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
+            recalcFunc, m_ctx, false, &responseText);
     }
 
     if (!m_stopGenerating) {
@@ -17,10 +17,10 @@
     },
     {
         "order": "b",
-        "md5sum": "48de9538c774188eb25a7e9ee024bbd3",
+        "md5sum": "f692417a22405d80573ac10cb0cd6c6a",
         "name": "Mistral OpenOrca",
-        "filename": "mistral-7b-openorca.Q4_0.gguf",
-        "filesize": "4108927744",
+        "filename": "mistral-7b-openorca.Q4_0.gguf2.gguf",
+        "filesize": "4108928128",
         "requires": "2.5.0",
         "ramrequired": "8",
         "parameters": "7 billion",
@@ -28,7 +28,7 @@
         "type": "Mistral",
         "description": "<strong>Best overall fast chat model</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by Mistral AI<li>Finetuned on OpenOrca dataset curated via <a href=\"https://atlas.nomic.ai/\">Nomic Atlas</a><li>Licensed for commercial use</ul>",
         "url": "https://gpt4all.io/models/gguf/mistral-7b-openorca.Q4_0.gguf",
-        "promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
+        "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
         "systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>"
     },
     {
@@ -152,7 +152,7 @@
         "type": "MPT",
         "description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
         "url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
-        "promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
+        "promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
        "systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
     },
     {
@@ -951,7 +951,7 @@ void ModelList::updateModelsFromDirectory()
         processDirectory(localPath);
 }
 
-#define MODELS_VERSION 2
+#define MODELS_VERSION 3
 
 void ModelList::updateModelsFromJson()
 {