Fix the way we're injecting the context back into the model for web search.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-07-27 11:11:41 -04:00
parent c78c95ab42
commit dda59a97a6
3 changed files with 12 additions and 5 deletions

View File

@ -116,6 +116,7 @@ void Chat::resetResponseState()
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
return; return;
m_sourceExcerpts = QList<SourceExcerpt>();
m_generatedQuestions = QList<QString>(); m_generatedQuestions = QList<QString>();
emit generatedQuestionsChanged(); emit generatedQuestionsChanged();
m_tokenSpeed = QString(); m_tokenSpeed = QString();
@ -136,6 +137,7 @@ void Chat::prompt(const QString &prompt)
void Chat::regenerateResponse() void Chat::regenerateResponse()
{ {
const int index = m_chatModel->count() - 1; const int index = m_chatModel->count() - 1;
m_sourceExcerpts = QList<SourceExcerpt>();
m_chatModel->updateSources(index, QList<SourceExcerpt>()); m_chatModel->updateSources(index, QList<SourceExcerpt>());
emit regenerateResponseRequested(); emit regenerateResponseRequested();
} }

View File

@ -869,13 +869,13 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))"); static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))");
QRegularExpressionMatch match = re.match(toolCall); QRegularExpressionMatch match = re.match(toolCall);
QString prompt("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2"); QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
QString query; QString query;
if (match.hasMatch()) { if (match.hasMatch()) {
query = match.captured(1); query = match.captured(1);
} else { } else {
qWarning() << "WARNING: Could not find the tool for " << toolCall; qWarning() << "WARNING: Could not find the tool for " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(QString()), QString("%1") /*promptTemplate*/, return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, promptTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
} }
@ -887,7 +887,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
emit sourceExcerptsChanged(braveResponse.second); emit sourceExcerptsChanged(braveResponse.second);
return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(braveResponse.first), QString("%1") /*promptTemplate*/, return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, promptTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens); n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
} }

View File

@ -219,8 +219,13 @@ public:
if (index < 0 || index >= m_chatItems.size()) return; if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
item.sources = sources; if (sources.isEmpty()) {
item.consolidatedSources = consolidateSources(sources); item.sources.clear();
item.consolidatedSources.clear();
} else {
item.sources << sources;
item.consolidatedSources << consolidateSources(sources);
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
} }