chat: generate follow-up questions after response (#2634)

* the user can configure the prompt and when the follow-up questions appear
* also make the name generation prompt configurable

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Author: AT
Date: 2024-07-10 15:45:20 -04:00
Committed by: GitHub
Parent: ef4e362d92
Commit: 66bc04aa8e
14 changed files with 621 additions and 138 deletions


@@ -750,7 +750,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     if (!databaseResults.isEmpty()) {
         QStringList results;
         for (const ResultInfo &info : databaseResults)
-            results << u"Collection: %1\nPath: %2\nSnippet: %3"_s.arg(info.collection, info.path, info.text);
+            results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text);
         // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template
         docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
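
For illustration, a single LocalDocs result is now rendered into the hardcoded Alpaca-style context block roughly as follows. This is a sketch with made-up collection/path/text values; only the two format strings are taken from the hunk above.

#include <QString>
#include <QStringList>

using namespace Qt::Literals::StringLiterals;

// Sketch with illustrative values: how one ResultInfo would be formatted
// into docsContext after this change ("Excerpt" replaces "Snippet").
QString buildDocsContextExample()
{
    QStringList results;
    results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s
                   .arg(u"my-notes"_s, u"/home/user/notes.txt"_s, u"...the retrieved passage..."_s);
    // Produces:
    //   ### Context:
    //   Collection: my-notes
    //   Path: /home/user/notes.txt
    //   Excerpt: ...the retrieved passage...
    return u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
}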
@@ -797,7 +797,13 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_response = trimmed;
         emit responseChanged(QString::fromStdString(m_response));
     }
-    emit responseStopped(elapsed);
+
+    SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
+    if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly))
+        generateQuestions(elapsed);
+    else
+        emit responseStopped(elapsed);
+
     m_pristineLoadedState = false;
     return true;
 }
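
The gating above relies on a new suggestion-mode setting. Only the getter MySettings::suggestionMode() and the enumerators SuggestionMode::On and SuggestionMode::LocalDocsOnly are visible in this hunk; the following is a minimal sketch of the idea, assuming a third Off value (the real enum, its ordering, and its storage live in this commit's settings changes and may differ).

// Assumed shape of the setting; On and LocalDocsOnly appear in the diff,
// Off is an assumption.
enum class SuggestionMode { LocalDocsOnly, On, Off };

// Mirrors the condition added above: suggest follow-ups after every
// response when On, and only for responses that used LocalDocs results
// when LocalDocsOnly.
static bool shouldSuggestFollowUps(SuggestionMode mode, bool usedLocalDocs)
{
    return mode == SuggestionMode::On
        || (usedLocalDocs && mode == SuggestionMode::LocalDocsOnly);
}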
@@ -875,13 +881,19 @@ void ChatLLM::generateName()
     if (!isModelLoaded())
         return;
+    const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo);
+    if (chatNamePrompt.trimmed().isEmpty()) {
+        qWarning() << "ChatLLM: not generating chat name because prompt is empty";
+        return;
+    }
+
     auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
     auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
     LLModel::PromptContext ctx = m_ctx;
-    m_llModelInfo.model->prompt("Describe the above conversation in seven words or less.",
-                                promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+    m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
+                                promptFunc, responseFunc, recalcFunc, ctx);
     std::string trimmed = trim_whitespace(m_nameResponse);
     if (trimmed != m_nameResponse) {
         m_nameResponse = trimmed;
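
The instruction that used to be hardcoded here presumably survives as the default value of the new modelChatNamePrompt setting; that default is defined elsewhere in this commit, not in this hunk. A hedged sketch of such a default:

// Assumption: the previously hardcoded instruction becomes the default
// chat-name prompt; the actual default value and its storage are part of
// this commit's settings changes, which are not shown in this excerpt.
static const QString kDefaultChatNamePrompt =
    QStringLiteral("Describe the above conversation in seven words or less.");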
@@ -901,7 +913,6 @@ bool ChatLLM::handleNamePrompt(int32_t token)
     qDebug() << "name prompt" << m_llmThread.objectName() << token;
 #endif
     Q_UNUSED(token);
-    qt_noop();
     return !m_stopGenerating;
 }
@@ -925,10 +936,84 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
     qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
 #endif
     Q_UNUSED(isRecalc);
-    qt_noop();
     return true;
 }
+
+bool ChatLLM::handleQuestionPrompt(int32_t token)
+{
+#if defined(DEBUG)
+    qDebug() << "question prompt" << m_llmThread.objectName() << token;
+#endif
+    Q_UNUSED(token);
+    return !m_stopGenerating;
+}
+
+bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
+{
+#if defined(DEBUG)
+    qDebug() << "question response" << m_llmThread.objectName() << token << response;
+#endif
+    Q_UNUSED(token);
+
+    // add token to buffer
+    m_questionResponse.append(response);
+
+    // match whole question sentences
+    static const QRegularExpression reQuestion(R"(\b(What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
+
+    // extract all questions from response
+    int lastMatchEnd = -1;
+    for (const auto &match : reQuestion.globalMatch(m_questionResponse)) {
+        lastMatchEnd = match.capturedEnd();
+        emit generatedQuestionFinished(match.captured());
+    }
+
+    // remove processed input from buffer
+    if (lastMatchEnd != -1)
+        m_questionResponse.erase(m_questionResponse.cbegin(), m_questionResponse.cbegin() + lastMatchEnd);
+
+    return true;
+}
+
+bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
+{
+#if defined(DEBUG)
+    qDebug() << "question recalc" << m_llmThread.objectName() << isRecalc;
+#endif
+    Q_UNUSED(isRecalc);
+    return true;
+}
+
+void ChatLLM::generateQuestions(qint64 elapsed)
+{
+    Q_ASSERT(isModelLoaded());
+    if (!isModelLoaded()) {
+        emit responseStopped(elapsed);
+        return;
+    }
+
+    const std::string suggestedFollowUpPrompt = MySettings::globalInstance()->modelSuggestedFollowUpPrompt(m_modelInfo).toStdString();
+    if (QString::fromStdString(suggestedFollowUpPrompt).trimmed().isEmpty()) {
+        emit responseStopped(elapsed);
+        return;
+    }
+
+    emit generatingQuestions();
+    m_questionResponse.clear();
+    auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
+    auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
+    auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
+    auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
+    LLModel::PromptContext ctx = m_ctx;
+    QElapsedTimer totalTime;
+    totalTime.start();
+    m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
+                                promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+    elapsed += totalTime.elapsed();
+    emit responseStopped(elapsed);
+}
+
 bool ChatLLM::handleSystemPrompt(int32_t token)
 {
 #if defined(DEBUG)
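
The extraction in handleQuestionResponse is incremental: each streamed token is appended to a buffer, every complete question the regex can match is emitted immediately, and the matched prefix is discarded so a partially generated question stays buffered until more tokens arrive. A standalone sketch of that buffering scheme, using a plain function and return value instead of the member buffer and signal above (not code from this commit):

#include <QRegularExpression>
#include <QString>
#include <QStringList>

// Sketch of the streaming extraction idea used by handleQuestionResponse:
// append a chunk to the buffer, return every complete question, and keep
// only the unmatched tail for the next chunk.
static QStringList extractQuestions(QString &buffer, const QString &chunk)
{
    static const QRegularExpression reQuestion(
        R"(\b(What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");

    buffer.append(chunk);

    QStringList questions;
    qsizetype lastMatchEnd = -1;
    for (const QRegularExpressionMatch &match : reQuestion.globalMatch(buffer)) {
        lastMatchEnd = match.capturedEnd();
        questions << match.captured();
    }

    // drop everything up to the last complete question; a question that has
    // not reached its '?' yet remains buffered
    if (lastMatchEnd != -1)
        buffer.remove(0, lastMatchEnd);

    return questions;
}

// Usage: a question is only returned once its closing '?' has streamed in.
//   QString buf;
//   extractQuestions(buf, "1. What is");     // -> {}
//   extractQuestions(buf, " RAG used for?"); // -> {"What is RAG used for?"}

Emitting each question as soon as its match completes is what lets the UI show suggestions while the model is still streaming the rest of the list.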