From 00cb5fe2a5c29e6dee858b1a56603e21c5129ea9 Mon Sep 17 00:00:00 2001
From: Aaron Miller
Date: Tue, 18 Apr 2023 06:46:03 -0700
Subject: [PATCH] Add thread count setting

---
 gptj.cpp  |  8 ++++++++
 gptj.h    |  4 +++-
 llm.cpp   | 22 ++++++++++++++++++++++
 llm.h     |  9 +++++++++
 llmodel.h |  4 +++-
 main.qml  | 38 ++++++++++++++++++++++++++++++++------
 6 files changed, 77 insertions(+), 8 deletions(-)

diff --git a/gptj.cpp b/gptj.cpp
index db3da3b7..0e40a4a1 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -659,6 +659,14 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
     return true;
 }
 
+void GPTJ::setThreadCount(int32_t n_threads) {
+    d_ptr->n_threads = n_threads;
+}
+
+int32_t GPTJ::threadCount() {
+    return d_ptr->n_threads;
+}
+
 GPTJ::~GPTJ()
 {
     ggml_free(d_ptr->model.ctx);
diff --git a/gptj.h b/gptj.h
index 66f8aa5c..59c0a79c 100644
--- a/gptj.h
+++ b/gptj.h
@@ -17,9 +17,11 @@ public:
     void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
         float temp = 0.0f, int32_t n_batch = 9) override;
+    void setThreadCount(int32_t n_threads) override;
+    int32_t threadCount() override;
 
 private:
     GPTJPrivate *d_ptr;
 };
 
-#endif // GPTJ_H
\ No newline at end of file
+#endif // GPTJ_H
diff --git a/llm.cpp b/llm.cpp
index a73b3fab..d0bb85d5 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -62,6 +62,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
         auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
         m_llmodel->loadModel(modelName.toStdString(), fin);
         emit isModelLoadedChanged();
+        emit threadCountChanged();
     }
 
     if (m_llmodel)
@@ -70,6 +71,15 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
     return m_llmodel;
 }
 
+void LLMObject::setThreadCount(int32_t n_threads) {
+    m_llmodel->setThreadCount(n_threads);
+    emit threadCountChanged();
+}
+
+int32_t LLMObject::threadCount() {
+    return m_llmodel->threadCount();
+}
+
 bool LLMObject::isModelLoaded() const
 {
     return m_llmodel && m_llmodel->isModelLoaded();
@@ -225,6 +235,9 @@ LLM::LLM()
     connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
+
+
     connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
     connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@@ -233,6 +246,7 @@ LLM::LLM()
     connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
     connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
     connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
+    connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
 }
 
 bool LLM::isModelLoaded() const
@@ -300,6 +314,14 @@ QList<QString> LLM::modelList() const
     return m_llmodel->modelList();
 }
 
+void LLM::setThreadCount(int32_t n_threads) {
+    emit setThreadCountRequested(n_threads);
+}
+
+int32_t LLM::threadCount() {
+    return m_llmodel->threadCount();
+}
+
 bool LLM::checkForUpdates() const
 {
 #if defined(Q_OS_LINUX)
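Note on the llm.cpp hunks above: LLMObject lives on a worker thread, so the GUI-facing LLM::setThreadCount() never touches the model directly; it emits setThreadCountRequested, which the queued connection delivers to LLMObject::setThreadCount on the worker thread. The getter, by contrast, calls straight through. Below is a minimal, self-contained sketch of that request-signal pattern; the Worker and Facade names are illustrative and not part of this patch.

    #include <QCoreApplication>
    #include <QThread>
    #include <QTimer>
    #include <QDebug>

    // Worker-side object: owns the mutable state, runs on its own thread.
    class Worker : public QObject {
        Q_OBJECT
    public Q_SLOTS:
        void setThreadCount(int n) {
            // Invoked on the worker thread via the queued connection, so no
            // locking is needed for state only this thread touches.
            qDebug() << "setThreadCount(" << n << ") on" << QThread::currentThread();
        }
    };

    // GUI-side facade: converts a plain call into a cross-thread request.
    class Facade : public QObject {
        Q_OBJECT
    public:
        void setThreadCount(int n) { Q_EMIT setThreadCountRequested(n); }
    Q_SIGNALS:
        void setThreadCountRequested(int n);
    };

    int main(int argc, char *argv[]) {
        QCoreApplication app(argc, argv);

        QThread thread;
        Worker worker;
        worker.moveToThread(&thread);
        thread.start();

        Facade facade;
        QObject::connect(&facade, &Facade::setThreadCountRequested,
                         &worker, &Worker::setThreadCount, Qt::QueuedConnection);

        facade.setThreadCount(8); // marshalled to the worker, not a direct call

        QTimer::singleShot(100, &app, [&app] { app.quit(); });
        int rc = app.exec();
        thread.quit();
        thread.wait();
        return rc;
    }

    // Assumes this file is main.cpp, built with qmake or CMake AUTOMOC.
    #include "main.moc"

This also mirrors why the patch wires setThreadCountRequested with Qt::QueuedConnection rather than Qt::BlockingQueuedConnection as used for the reset signals: changing the thread count has no result the GUI must wait for.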
diff --git a/llm.h b/llm.h
index 0e189e42..2c54e634 100644
--- a/llm.h
+++ b/llm.h
@@ -12,6 +12,7 @@ class LLMObject : public QObject
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
+    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
 
 public:
@@ -22,6 +23,8 @@ public:
     void resetResponse();
     void resetContext();
     void stopGenerating() { m_stopGenerating = true; }
+    void setThreadCount(int32_t n_threads);
+    int32_t threadCount();
 
     QString response() const;
     QString modelName() const;
@@ -42,6 +45,7 @@ Q_SIGNALS:
     void responseStopped();
     void modelNameChanged();
     void modelListChanged();
+    void threadCountChanged();
 
 private:
     bool loadModelPrivate(const QString &modelName);
@@ -65,6 +69,7 @@ class LLM : public QObject
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
     Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
+    Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
 
 public:
     static LLM *globalInstance();
@@ -76,6 +81,8 @@ public:
     Q_INVOKABLE void resetResponse();
     Q_INVOKABLE void resetContext();
     Q_INVOKABLE void stopGenerating();
+    Q_INVOKABLE void setThreadCount(int32_t n_threads);
+    Q_INVOKABLE int32_t threadCount();
 
     QString response() const;
     bool responseInProgress() const { return m_responseInProgress; }
@@ -99,6 +106,8 @@ Q_SIGNALS:
     void modelNameChangeRequested(const QString &modelName);
     void modelNameChanged();
     void modelListChanged();
+    void threadCountChanged();
+    void setThreadCountRequested(int32_t threadCount);
 
 private Q_SLOTS:
     void responseStarted();
diff --git a/llmodel.h b/llmodel.h
index da52d190..3ffb8420 100644
--- a/llmodel.h
+++ b/llmodel.h
@@ -19,6 +19,8 @@ public:
     virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
         float temp = 0.9f, int32_t n_batch = 9) = 0;
+    virtual void setThreadCount(int32_t n_threads) {}
+    virtual int32_t threadCount() { return 1; }
 };
 
-#endif // LLMODEL_H
\ No newline at end of file
+#endif // LLMODEL_H
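Note on the llmodel.h hunk above: the two new virtuals carry inline default bodies so that LLModel's vtable stays fully defined and backends that drive their own threading can simply ignore the setting, while GPTJ overrides both. A compilable sketch of that pattern in isolation (the class names are illustrative, not from the patch):

    #include <cstdint>
    #include <iostream>

    class Model {
    public:
        virtual ~Model() = default;
        // Defaults: ignore the setting and report a single thread, so a
        // subclass is free not to override either function.
        virtual void setThreadCount(int32_t) {}
        virtual int32_t threadCount() { return 1; }
    };

    class ThreadedModel : public Model {
    public:
        void setThreadCount(int32_t n) override { m_threads = n; }
        int32_t threadCount() override { return m_threads; }
    private:
        int32_t m_threads = 4;
    };

    int main() {
        ThreadedModel m;
        m.setThreadCount(8);
        std::cout << m.threadCount() << '\n'; // prints 8
    }

Had the virtuals been declared without bodies (and without "= 0"), concrete subclasses would typically fail to link, since the compiler still needs their definitions to emit the base-class vtable.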
diff --git a/main.qml b/main.qml
index 002f8060..e677db93 100644
--- a/main.qml
+++ b/main.qml
@@ -107,7 +106,6 @@ Window {
     property int defaultTopK: 40
     property int defaultMaxLength: 4096
     property int defaultPromptBatchSize: 9
-    property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n### Prompt:\n%1\n### Response:\n"
@@ -141,7 +140,7 @@ Window {
         GridLayout {
             columns: 2
-            rowSpacing: 10
+            rowSpacing: 2
             columnSpacing: 10
             anchors.fill: parent
@@ -278,14 +277,41 @@ Window {
            }
 
            Label {
-                id: promptTemplateLabel
-                text: qsTr("Prompt Template:")
+                id: nThreadsLabel
+                text: qsTr("CPU Threads")
                Layout.row: 5
                Layout.column: 0
            }
-            Rectangle {
+            TextField {
+                text: LLM.threadCount.toString()
+                ToolTip.text: qsTr("Number of processing threads to use")
+                ToolTip.visible: hovered
                Layout.row: 5
                Layout.column: 1
+                validator: IntValidator { bottom: 1 }
+                onAccepted: {
+                    var val = parseInt(text)
+                    if (!isNaN(val)) {
+                        LLM.threadCount = val
+                        focus = false
+                    } else {
+                        text = LLM.threadCount.toString()
+                    }
+                }
+                Accessible.role: Accessible.EditableText
+                Accessible.name: nThreadsLabel.text
+                Accessible.description: ToolTip.text
+            }
+
+            Label {
+                id: promptTemplateLabel
+                text: qsTr("Prompt Template:")
+                Layout.row: 6
+                Layout.column: 0
+            }
+            Rectangle {
+                Layout.row: 6
+                Layout.column: 1
                Layout.fillWidth: true
                height: 200
                color: "transparent"
@@ -319,7 +345,7 @@ Window {
                }
            }
            Button {
-                Layout.row: 6
+                Layout.row: 7
                Layout.column: 1
                Layout.fillWidth: true
                padding: 15
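One caveat on the API as it stands: unlike the setter, LLM::threadCount() reads m_llmodel->threadCount() directly from the calling (GUI) thread while the worker may be applying a queued update, and LLMObject::threadCount() dereferences the model even if none has been loaded yet, which matters because main.qml evaluates LLM.threadCount as soon as the settings dialog is instantiated. A plain int32_t read from another thread is formally a data race, though unlikely to misbehave for an aligned int on common platforms. A defensive sketch of the storage side; the holder class is illustrative, not part of the patch:

    #include <atomic>
    #include <cstdint>

    // Keeping the count in a std::atomic makes the direct cross-thread read
    // well-defined even while a queued setter is running on the worker thread.
    class ThreadCountHolder {
    public:
        void setThreadCount(int32_t n) { m_threads.store(n, std::memory_order_relaxed); }
        int32_t threadCount() const { return m_threads.load(std::memory_order_relaxed); }
    private:
        std::atomic<int32_t> m_threads{1};
    };

Guarding LLMObject::setThreadCount and LLMObject::threadCount with an "if (m_llmodel)" check, falling back to a placeholder of 1, would likewise close the startup window where no model object exists yet.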