Add thread count setting

This commit is contained in:
Aaron Miller 2023-04-18 06:46:03 -07:00 committed by AT
parent 169afbdc80
commit 00cb5fe2a5
6 changed files with 78 additions and 8 deletions

View File

@ -659,6 +659,14 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
return true; return true;
} }
// Records the number of CPU threads to use for inference.
// The value is stored on the private implementation (pimpl); the visible
// code only stores it — presumably it is read on the next generation call.
void GPTJ::setThreadCount(int32_t n_threads) {
d_ptr->n_threads = n_threads;
}
// Returns the CPU thread count currently stored on the private
// implementation (the value last set via setThreadCount, or the
// default assigned elsewhere in GPTJPrivate — not visible here).
int32_t GPTJ::threadCount() {
return d_ptr->n_threads;
}
GPTJ::~GPTJ() GPTJ::~GPTJ()
{ {
ggml_free(d_ptr->model.ctx); ggml_free(d_ptr->model.ctx);

2
gptj.h
View File

@ -17,6 +17,8 @@ public:
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response, void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f, PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
float temp = 0.0f, int32_t n_batch = 9) override; float temp = 0.0f, int32_t n_batch = 9) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;
private: private:
GPTJPrivate *d_ptr; GPTJPrivate *d_ptr;

22
llm.cpp
View File

@ -62,6 +62,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary); auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
m_llmodel->loadModel(modelName.toStdString(), fin); m_llmodel->loadModel(modelName.toStdString(), fin);
emit isModelLoadedChanged(); emit isModelLoadedChanged();
emit threadCountChanged();
} }
if (m_llmodel) if (m_llmodel)
@ -70,6 +71,15 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
return m_llmodel; return m_llmodel;
} }
// Slot: applies a new CPU thread count to the underlying model and
// notifies property bindings.
//
// Fix: guard the m_llmodel dereference. This slot is reached via a queued
// connection (see the LLM ctor) and can fire before loadModelPrivate() has
// created the model; the rest of this file consistently checks
// `if (m_llmodel)` before use. The notify signal is still emitted
// unconditionally, matching the original observable behavior.
void LLMObject::setThreadCount(int32_t n_threads) {
    if (m_llmodel)
        m_llmodel->setThreadCount(n_threads);
    emit threadCountChanged();
}
// Returns the CPU thread count of the loaded model.
//
// Fix: guard the m_llmodel dereference — this is a property READ accessor
// and can be invoked (e.g. from QML bindings) before a model has been
// loaded; the rest of this file consistently checks `if (m_llmodel)`.
// Returns 0 ("unset") when no model is loaded yet.
int32_t LLMObject::threadCount() {
    if (!m_llmodel)
        return 0;
    return m_llmodel->threadCount();
}
bool LLMObject::isModelLoaded() const bool LLMObject::isModelLoaded() const
{ {
return m_llmodel && m_llmodel->isModelLoaded(); return m_llmodel && m_llmodel->isModelLoaded();
@ -225,6 +235,9 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection); connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection); connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@ -233,6 +246,7 @@ LLM::LLM()
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection); connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection); connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
} }
bool LLM::isModelLoaded() const bool LLM::isModelLoaded() const
@ -300,6 +314,14 @@ QList<QString> LLM::modelList() const
return m_llmodel->modelList(); return m_llmodel->modelList();
} }
// Q_INVOKABLE WRITE accessor for the threadCount property. The LLMObject
// lives on a worker thread, so rather than calling it directly the value
// is forwarded through the setThreadCountRequested signal, which the LLM
// ctor wires to LLMObject::setThreadCount via a queued connection.
void LLM::setThreadCount(int32_t n_threads) {
emit setThreadCountRequested(n_threads);
}
// Q_INVOKABLE READ accessor for the threadCount property; reads straight
// from the worker-side LLMObject.
// NOTE(review): assumes m_llmodel is non-null — presumably constructed in
// the LLM ctor (the ctor body visible here already uses it) — and this is
// a direct cross-thread read with no synchronization; confirm acceptable.
int32_t LLM::threadCount() {
return m_llmodel->threadCount();
}
bool LLM::checkForUpdates() const bool LLM::checkForUpdates() const
{ {
#if defined(Q_OS_LINUX) #if defined(Q_OS_LINUX)

10
llm.h
View File

@ -12,6 +12,8 @@ class LLMObject : public QObject
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
public: public:
@ -22,6 +24,8 @@ public:
void resetResponse(); void resetResponse();
void resetContext(); void resetContext();
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
QString response() const; QString response() const;
QString modelName() const; QString modelName() const;
@ -42,6 +46,7 @@ Q_SIGNALS:
void responseStopped(); void responseStopped();
void modelNameChanged(); void modelNameChanged();
void modelListChanged(); void modelListChanged();
void threadCountChanged();
private: private:
bool loadModelPrivate(const QString &modelName); bool loadModelPrivate(const QString &modelName);
@ -65,6 +70,7 @@ class LLM : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged) Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
public: public:
static LLM *globalInstance(); static LLM *globalInstance();
@ -76,6 +82,8 @@ public:
Q_INVOKABLE void resetResponse(); Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext(); Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void setThreadCount(int32_t n_threads);
Q_INVOKABLE int32_t threadCount();
QString response() const; QString response() const;
bool responseInProgress() const { return m_responseInProgress; } bool responseInProgress() const { return m_responseInProgress; }
@ -99,6 +107,8 @@ Q_SIGNALS:
void modelNameChangeRequested(const QString &modelName); void modelNameChangeRequested(const QString &modelName);
void modelNameChanged(); void modelNameChanged();
void modelListChanged(); void modelListChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
private Q_SLOTS: private Q_SLOTS:
void responseStarted(); void responseStarted();

View File

@ -19,6 +19,8 @@ public:
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response, virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f, PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
float temp = 0.9f, int32_t n_batch = 9) = 0; float temp = 0.9f, int32_t n_batch = 9) = 0;
virtual void setThreadCount(int32_t n_threads);
virtual int32_t threadCount();
}; };
#endif // LLMODEL_H #endif // LLMODEL_H

View File

@ -107,7 +107,6 @@ Window {
property int defaultTopK: 40 property int defaultTopK: 40
property int defaultMaxLength: 4096 property int defaultMaxLength: 4096
property int defaultPromptBatchSize: 9 property int defaultPromptBatchSize: 9
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt: ### Prompt:
%1 %1
@ -141,7 +140,7 @@ Window {
GridLayout { GridLayout {
columns: 2 columns: 2
rowSpacing: 10 rowSpacing: 2
columnSpacing: 10 columnSpacing: 10
anchors.fill: parent anchors.fill: parent
@ -278,14 +277,41 @@ Window {
} }
Label { Label {
id: promptTemplateLabel id: nThreadsLabel
text: qsTr("Prompt Template:") text: qsTr("CPU Threads")
Layout.row: 5 Layout.row: 5
Layout.column: 0 Layout.column: 0
} }
Rectangle { TextField {
text: LLM.threadCount.toString()
ToolTip.text: qsTr("Amount of processing threads to use")
ToolTip.visible: hovered
Layout.row: 5 Layout.row: 5
Layout.column: 1 Layout.column: 1
validator: IntValidator { bottom: 1 }
onAccepted: {
var val = parseInt(text)
if (!isNaN(val)) {
LLM.threadCount = val
focus = false
} else {
text = settingsDialog.nThreads.toString()
}
}
Accessible.role: Accessible.EditableText
Accessible.name: nThreadsLabel.text
Accessible.description: ToolTip.text
}
Label {
id: promptTemplateLabel
text: qsTr("Prompt Template:")
Layout.row: 6
Layout.column: 0
}
Rectangle {
Layout.row: 6
Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
height: 200 height: 200
color: "transparent" color: "transparent"
@ -319,7 +345,7 @@ Window {
} }
} }
Button { Button {
Layout.row: 6 Layout.row: 7
Layout.column: 1 Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
padding: 15 padding: 15