mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-08 20:09:12 +00:00
fix chat-style prompt templates (#1970)
Also use a new version of Mistral OpenOrca. Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
@@ -303,6 +303,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
|
||||
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
|
||||
|
||||
if (m_llModelInfo.model) {
|
||||
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
|
||||
// TODO(cebtenzzre): warn that this model is out-of-date
|
||||
}
|
||||
|
||||
m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
|
||||
emit modelLoadingPercentageChanged(progress);
|
||||
@@ -588,14 +591,11 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
}
|
||||
|
||||
// Augment the prompt template with the results if any
|
||||
QList<QString> augmentedTemplate;
|
||||
QList<QString> docsContext;
|
||||
if (!databaseResults.isEmpty())
|
||||
augmentedTemplate.append("### Context:");
|
||||
docsContext.append("### Context:");
|
||||
for (const ResultInfo &info : databaseResults)
|
||||
augmentedTemplate.append(info.text);
|
||||
augmentedTemplate.append(promptTemplate);
|
||||
|
||||
QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
|
||||
docsContext.append(info.text);
|
||||
|
||||
int n_threads = MySettings::globalInstance()->threadCount();
|
||||
|
||||
@@ -605,7 +605,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
std::placeholders::_2);
|
||||
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
|
||||
emit promptProcessing();
|
||||
qint32 logitsBefore = m_ctx.logits.size();
|
||||
m_ctx.n_predict = n_predict;
|
||||
m_ctx.top_k = top_k;
|
||||
m_ctx.top_p = top_p;
|
||||
@@ -615,11 +614,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
m_ctx.repeat_last_n = repeat_penalty_tokens;
|
||||
m_llModelInfo.model->setThreadCount(n_threads);
|
||||
#if defined(DEBUG)
|
||||
printf("%s", qPrintable(instructPrompt));
|
||||
printf("%s", qPrintable(prompt));
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_timer->start();
|
||||
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
if (!docsContext.isEmpty()) {
|
||||
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
|
||||
m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
m_ctx.n_predict = old_n_predict; // now we are ready for a response
|
||||
}
|
||||
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
#if defined(DEBUG)
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
@@ -720,7 +724,7 @@ void ChatLLM::generateName()
|
||||
printf("%s", qPrintable(instructPrompt));
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
|
||||
m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
|
||||
#if defined(DEBUG)
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
@@ -780,16 +784,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
|
||||
return !m_stopGenerating;
|
||||
}
|
||||
|
||||
bool ChatLLM::handleSystemResponse(int32_t token, const std::string &response)
|
||||
{
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "system response" << m_llmThread.objectName() << token << response << m_stopGenerating;
|
||||
#endif
|
||||
Q_UNUSED(token);
|
||||
Q_UNUSED(response);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
|
||||
{
|
||||
#if defined(DEBUG)
|
||||
@@ -808,16 +802,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
|
||||
return !m_stopGenerating;
|
||||
}
|
||||
|
||||
bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
|
||||
{
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
|
||||
#endif
|
||||
Q_UNUSED(token);
|
||||
Q_UNUSED(response);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
|
||||
{
|
||||
#if defined(DEBUG)
|
||||
@@ -1027,8 +1011,6 @@ void ChatLLM::processSystemPrompt()
|
||||
m_ctx = LLModel::PromptContext();
|
||||
|
||||
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
|
||||
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
|
||||
std::placeholders::_2);
|
||||
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
|
||||
|
||||
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
|
||||
@@ -1051,7 +1033,9 @@ void ChatLLM::processSystemPrompt()
|
||||
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_llModelInfo.model->prompt(systemPrompt, promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
|
||||
m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
|
||||
m_ctx.n_predict = old_n_predict;
|
||||
#if defined(DEBUG)
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
@@ -1073,8 +1057,6 @@ void ChatLLM::processRestoreStateFromText()
|
||||
m_ctx = LLModel::PromptContext();
|
||||
|
||||
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
|
||||
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
|
||||
std::placeholders::_2);
|
||||
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
|
||||
|
||||
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
|
||||
@@ -1094,9 +1076,19 @@ void ChatLLM::processRestoreStateFromText()
|
||||
m_ctx.repeat_penalty = repeat_penalty;
|
||||
m_ctx.repeat_last_n = repeat_penalty_tokens;
|
||||
m_llModelInfo.model->setThreadCount(n_threads);
|
||||
for (auto pair : m_stateFromText) {
|
||||
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
|
||||
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
|
||||
|
||||
auto it = m_stateFromText.begin();
|
||||
while (it < m_stateFromText.end()) {
|
||||
auto &prompt = *it++;
|
||||
Q_ASSERT(prompt.first == "Prompt: ");
|
||||
Q_ASSERT(it < m_stateFromText.end());
|
||||
|
||||
auto &response = *it++;
|
||||
Q_ASSERT(response.first != "Prompt: ");
|
||||
auto responseText = response.second.toStdString();
|
||||
|
||||
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
|
||||
recalcFunc, m_ctx, false, &responseText);
|
||||
}
|
||||
|
||||
if (!m_stopGenerating) {
|
||||
|
Reference in New Issue
Block a user