Show token generation speed in gui. (#1020)

This commit is contained in:
AT 2023-06-19 11:34:53 -07:00 committed by GitHub
parent fd419caa55
commit 2b6cc99a31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 84 additions and 2 deletions

View File

@ -57,6 +57,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
@ -102,6 +103,8 @@ void Chat::resetResponseState()
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
return; return;
m_tokenSpeed = QString();
emit tokenSpeedChanged();
m_responseInProgress = true; m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval; m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged(); emit responseInProgressChanged();
@ -187,6 +190,9 @@ void Chat::promptProcessing()
void Chat::responseStopped() void Chat::responseStopped()
{ {
m_tokenSpeed = QString();
emit tokenSpeedChanged();
const QString chatResponse = response(); const QString chatResponse = response();
QList<QString> references; QList<QString> references;
QList<QString> referencesContext; QList<QString> referencesContext;
@ -336,6 +342,12 @@ void Chat::handleModelLoadingError(const QString &error)
emit modelLoadingErrorChanged(); emit modelLoadingErrorChanged();
} }
// Slot for ChatLLM::reportSpeed (connected with Qt::QueuedConnection, so it
// runs on the GUI thread): caches the latest human-readable speed string
// (e.g. "12 tokens/sec") and notifies the QML `tokenSpeed` property binding.
void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
{
m_tokenSpeed = tokenSpeed;
emit tokenSpeedChanged();
}
bool Chat::serialize(QDataStream &stream, int version) const bool Chat::serialize(QDataStream &stream, int version) const
{ {
stream << m_creationDate; stream << m_creationDate;

View File

@ -25,6 +25,7 @@ class Chat : public QObject
Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged) Q_PROPERTY(QString responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged) Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged) Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
QML_ELEMENT QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!") QML_UNCREATABLE("Only creatable from c++!")
@ -91,6 +92,8 @@ public:
QString modelLoadingError() const { return m_modelLoadingError; } QString modelLoadingError() const { return m_modelLoadingError; }
QString tokenSpeed() const { return m_tokenSpeed; }
public Q_SLOTS: public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt); void serverNewPromptResponsePair(const QString &prompt);
@ -118,6 +121,7 @@ Q_SIGNALS:
void modelLoadingErrorChanged(); void modelLoadingErrorChanged();
void isServerChanged(); void isServerChanged();
void collectionListChanged(); void collectionListChanged();
void tokenSpeedChanged();
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(); void handleResponseChanged();
@ -128,6 +132,7 @@ private Q_SLOTS:
void handleRecalculating(); void handleRecalculating();
void handleModelNameChanged(); void handleModelNameChanged();
void handleModelLoadingError(const QString &error); void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
private: private:
QString m_id; QString m_id;
@ -135,6 +140,7 @@ private:
QString m_userName; QString m_userName;
QString m_savedModelName; QString m_savedModelName;
QString m_modelLoadingError; QString m_modelLoadingError;
QString m_tokenSpeed;
QList<QString> m_collections; QList<QString> m_collections;
ChatModel *m_chatModel; ChatModel *m_chatModel;
bool m_responseInProgress; bool m_responseInProgress;

View File

@ -94,6 +94,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_responseLogits(0) , m_responseLogits(0)
, m_isRecalc(false) , m_isRecalc(false)
, m_chat(parent) , m_chat(parent)
, m_timer(nullptr)
, m_isServer(isServer) , m_isServer(isServer)
, m_isChatGPT(false) , m_isChatGPT(false)
{ {
@ -103,7 +104,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
// The following are blocking operations and will block the llm thread // The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@ -126,6 +127,13 @@ ChatLLM::~ChatLLM()
} }
} }
// Runs when the llm worker thread starts: creates the token-speed timer on
// that thread (parented to this ChatLLM so it is destroyed with it) and
// forwards its reports to the reportSpeed signal, then re-emits
// threadStarted for any listeners that used the old direct connection.
void ChatLLM::handleThreadStarted()
{
    // Lazily create the timer: if the thread is ever restarted, creating a
    // second TokenTimer while the first is still connected would produce
    // duplicate speed reports. m_timer is initialized to nullptr in the ctor.
    if (!m_timer) {
        m_timer = new TokenTimer(this);
        connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
    }
    emit threadStarted();
}
bool ChatLLM::loadDefaultModel() bool ChatLLM::loadDefaultModel()
{ {
const QList<QString> models = m_chat->modelList(); const QList<QString> models = m_chat->modelList();
@ -367,6 +375,7 @@ bool ChatLLM::handlePrompt(int32_t token)
#endif #endif
++m_promptTokens; ++m_promptTokens;
++m_promptResponseTokens; ++m_promptResponseTokens;
m_timer->inc();
return !m_stopGenerating; return !m_stopGenerating;
} }
@ -387,6 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt // the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens; ++m_promptResponseTokens;
m_timer->inc();
Q_ASSERT(!response.empty()); Q_ASSERT(!response.empty());
m_response.append(response); m_response.append(response);
emit responseChanged(); emit responseChanged();
@ -441,11 +451,13 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
printf("%s", qPrintable(instructPrompt)); printf("%s", qPrintable(instructPrompt));
fflush(stdout); fflush(stdout);
#endif #endif
m_timer->start();
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG) #if defined(DEBUG)
printf("\n"); printf("\n");
fflush(stdout); fflush(stdout);
#endif #endif
m_timer->stop();
m_responseLogits += m_ctx.logits.size() - logitsBefore; m_responseLogits += m_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response); std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) { if (trimmed != m_response) {

View File

@ -23,6 +23,46 @@ struct LLModelInfo {
// must be able to serialize the information even if it is in the unloaded state // must be able to serialize the information even if it is in the unloaded state
}; };
// Measures token-generation throughput for one prompt/response cycle and
// periodically emits a human-readable "N tokens/sec" report. The clock is
// started lazily by the first inc() call so that model-load/setup time before
// the first token is not counted; an intermediate report is emitted roughly
// once per second, and stop() flushes a final one.
class TokenTimer : public QObject {
    Q_OBJECT
public:
    explicit TokenTimer(QObject *parent)
        : QObject(parent)
        , m_elapsed(0)
        , m_tokens(0) {}  // was previously left uninitialized before start()

    // Incremental arithmetic mean: the new average after the nth number is
    // (oldAvg * (n - 1) + newNumber) / n.
    static int rollingAverage(int oldAvg, int newNumber, int n)
    {
        return qRound(((float(oldAvg) * (n - 1)) + newNumber) / float(n));
    }

    // Reset counters for a new generation run; invalidating the timer defers
    // the actual clock start to the first inc().
    void start() { m_tokens = 0; m_elapsed = 0; m_time.invalidate(); }

    // Flush the final measurement and emit a last report (no-op if no token
    // was ever counted).
    void stop() { handleTimeout(); }

    // Called once per prompt/response token processed.
    void inc() {
        if (!m_time.isValid())
            m_time.start();
        ++m_tokens;
        if (m_time.elapsed() > 999)
            handleTimeout();
    }

Q_SIGNALS:
    void report(const QString &speed);

private Q_SLOTS:
    void handleTimeout()
    {
        // stop() may be reached before any token arrived; restart() on an
        // invalid QElapsedTimer is undefined behavior per the Qt docs.
        if (!m_time.isValid())
            return;
        m_elapsed += m_time.restart();
        // Guard the division below: with zero elapsed milliseconds the
        // original expression reported "inf tokens/sec".
        if (m_elapsed <= 0)
            return;
        emit report(QString("%1 tokens/sec").arg(m_tokens / float(m_elapsed / 1000.0f), 0, 'g', 2));
    }

private:
    QElapsedTimer m_time; // invalid until the first inc() of a run
    qint64 m_elapsed;     // accumulated active-generation milliseconds
    quint32 m_tokens;     // tokens counted since start()
};
class Chat; class Chat;
class ChatLLM : public QObject class ChatLLM : public QObject
{ {
@ -73,6 +113,7 @@ public Q_SLOTS:
void generateName(); void generateName();
void handleChatIdChanged(); void handleChatIdChanged();
void handleShouldBeLoadedChanged(); void handleShouldBeLoadedChanged();
void handleThreadStarted();
Q_SIGNALS: Q_SIGNALS:
void isModelLoadedChanged(); void isModelLoadedChanged();
@ -89,7 +130,7 @@ Q_SIGNALS:
void threadStarted(); void threadStarted();
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
protected: protected:
bool handlePrompt(int32_t token); bool handlePrompt(int32_t token);
@ -112,6 +153,7 @@ protected:
quint32 m_responseLogits; quint32 m_responseLogits;
QString m_modelName; QString m_modelName;
Chat *m_chat; Chat *m_chat;
TokenTimer *m_timer;
QByteArray m_state; QByteArray m_state;
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;

View File

@ -845,6 +845,16 @@ Window {
Accessible.description: qsTr("Controls generation of the response") Accessible.description: qsTr("Controls generation of the response")
} }
// Overlay label showing the current generation speed (e.g. "12 tokens/sec"),
// anchored just above the text input area. Bound to Chat's tokenSpeed
// property, which is reset to an empty string when a response starts and
// stops, so the label renders nothing while idle.
Text {
id: speed
anchors.bottom: textInputView.top
anchors.bottomMargin: 20
anchors.right: parent.right
anchors.rightMargin: 30
color: theme.mutedTextColor
text: currentChat.tokenSpeed
}
RectangularGlow { RectangularGlow {
id: effect id: effect
anchors.fill: textInputView anchors.fill: textInputView