fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-02-21 15:45:32 -05:00
committed by GitHub
parent b8f5c74f40
commit 4fc4d94be4
22 changed files with 429 additions and 307 deletions

View File

@@ -75,13 +75,18 @@ size_t ChatGPT::restoreState(const uint8_t *src)
}
void ChatGPT::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
PromptContext &promptCtx,
bool special,
std::string *fakeReply) {
Q_UNUSED(promptCallback);
Q_UNUSED(recalculateCallback);
Q_UNUSED(special);
Q_UNUSED(fakeReply); // FIXME(cebtenzzre): I broke ChatGPT
if (!isModelLoaded()) {
std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
@@ -109,7 +114,7 @@ void ChatGPT::prompt(const std::string &prompt,
QJsonObject promptObject;
promptObject.insert("role", "user");
promptObject.insert("content", QString::fromStdString(prompt));
promptObject.insert("content", QString::fromStdString(promptTemplate).arg(QString::fromStdString(prompt)));
messages.append(promptObject);
root.insert("messages", messages);

View File

@@ -1,6 +1,8 @@
#ifndef CHATGPT_H
#define CHATGPT_H
#include <stdexcept>
#include <QObject>
#include <QNetworkReply>
#include <QNetworkRequest>
@@ -55,10 +57,13 @@ public:
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override;
PromptContext &ctx,
bool special,
std::string *fakeReply) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
@@ -69,7 +74,7 @@ public:
QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }
bool callResponse(int32_t token, const std::string& string);
bool callResponse(int32_t token, const std::string &string);
Q_SIGNALS:
void request(const QString &apiKey,
@@ -80,12 +85,41 @@ protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
std::string tokenToString(Token) const override { return std::string(); }
Token sampleToken(PromptContext &ctx) const override { return -1; }
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
int32_t contextLength() const override { return -1; }
const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
(void)ctx;
(void)str;
(void)special;
throw std::logic_error("not implemented");
}
std::string tokenToString(Token id) const override {
(void)id;
throw std::logic_error("not implemented");
}
Token sampleToken(PromptContext &ctx) const override {
(void)ctx;
throw std::logic_error("not implemented");
}
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}
int32_t contextLength() const override {
throw std::logic_error("not implemented");
}
const std::vector<Token> &endTokens() const override {
throw std::logic_error("not implemented");
}
bool shouldAddBOS() const override {
throw std::logic_error("not implemented");
}
private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;

View File

@@ -303,6 +303,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) {
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
// TODO(cebtenzzre): warn that this model is out-of-date
}
m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
emit modelLoadingPercentageChanged(progress);
@@ -588,14 +591,11 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
}
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
QList<QString> docsContext;
if (!databaseResults.isEmpty())
augmentedTemplate.append("### Context:");
docsContext.append("### Context:");
for (const ResultInfo &info : databaseResults)
augmentedTemplate.append(info.text);
augmentedTemplate.append(promptTemplate);
QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
docsContext.append(info.text);
int n_threads = MySettings::globalInstance()->threadCount();
@@ -605,7 +605,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
qint32 logitsBefore = m_ctx.logits.size();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
@@ -615,11 +614,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
printf("%s", qPrintable(prompt));
fflush(stdout);
#endif
m_timer->start();
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -720,7 +724,7 @@ void ChatLLM::generateName()
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -780,16 +784,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -808,16 +802,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -1027,8 +1011,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1051,7 +1033,9 @@ void ChatLLM::processSystemPrompt()
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -1073,8 +1057,6 @@ void ChatLLM::processRestoreStateFromText()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
@@ -1094,9 +1076,19 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());
auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
}
if (!m_stopGenerating) {

View File

@@ -17,10 +17,10 @@
},
{
"order": "b",
"md5sum": "48de9538c774188eb25a7e9ee024bbd3",
"md5sum": "f692417a22405d80573ac10cb0cd6c6a",
"name": "Mistral OpenOrca",
"filename": "mistral-7b-openorca.Q4_0.gguf",
"filesize": "4108927744",
"filename": "mistral-7b-openorca.Q4_0.gguf2.gguf",
"filesize": "4108928128",
"requires": "2.5.0",
"ramrequired": "8",
"parameters": "7 billion",
@@ -28,7 +28,7 @@
"type": "Mistral",
"description": "<strong>Best overall fast chat model</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by Mistral AI<li>Finetuned on OpenOrca dataset curated via <a href=\"https://atlas.nomic.ai/\">Nomic Atlas</a><li>Licensed for commercial use</ul>",
"url": "https://gpt4all.io/models/gguf/mistral-7b-openorca.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>"
},
{
@@ -152,7 +152,7 @@
"type": "MPT",
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|><|im_start|>assistant\n",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
},
{

View File

@@ -951,7 +951,7 @@ void ModelList::updateModelsFromDirectory()
processDirectory(localPath);
}
#define MODELS_VERSION 2
#define MODELS_VERSION 3
void ModelList::updateModelsFromJson()
{