fix regressions in system prompt handling (#2219)

* python: fix system prompt being ignored
* fix unintended whitespace after system prompt

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-04-15 11:39:48 -04:00
committed by GitHub
parent 2273cf145e
commit ac498f79ac
4 changed files with 12 additions and 19 deletions

View File

@@ -733,23 +733,13 @@ void ChatLLM::generateName()
if (!isModelLoaded())
return;
QString instructPrompt("### Instruction:\n"
"Describe response above in three words.\n"
"### Response:\n");
std::string instructPrompt("### Instruction:\n%1\n### Response:\n"); // standard Alpaca
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
m_llModelInfo.model->prompt("Describe response above in three words.", instructPrompt, promptFunc, responseFunc,
recalcFunc, ctx);
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
@@ -1056,7 +1046,8 @@ void ChatLLM::processSystemPrompt()
fflush(stdout);
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");