fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-02-21 15:45:32 -05:00
committed by GitHub
parent b8f5c74f40
commit 4fc4d94be4
22 changed files with 429 additions and 307 deletions

View File

@@ -303,6 +303,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) {
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
// TODO(cebtenzzre): warn that this model is out-of-date
}
m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
emit modelLoadingPercentageChanged(progress);
@@ -588,14 +591,11 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
}
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
QList<QString> docsContext;
if (!databaseResults.isEmpty())
augmentedTemplate.append("### Context:");
docsContext.append("### Context:");
for (const ResultInfo &info : databaseResults)
augmentedTemplate.append(info.text);
augmentedTemplate.append(promptTemplate);
QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
docsContext.append(info.text);
int n_threads = MySettings::globalInstance()->threadCount();
@@ -605,7 +605,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
qint32 logitsBefore = m_ctx.logits.size();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
@@ -615,11 +614,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
printf("%s", qPrintable(prompt));
fflush(stdout);
#endif
m_timer->start();
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -720,7 +724,7 @@ void ChatLLM::generateName()
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -780,16 +784,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -808,16 +802,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
@@ -1027,8 +1011,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1051,7 +1033,9 @@ void ChatLLM::processSystemPrompt()
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -1073,8 +1057,6 @@ void ChatLLM::processRestoreStateFromText()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
@@ -1094,9 +1076,19 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());
auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
}
if (!m_stopGenerating) {