mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-27 07:48:19 +00:00
Add thread count setting
This commit is contained in:
parent
169afbdc80
commit
00cb5fe2a5
8
gptj.cpp
8
gptj.cpp
@ -659,6 +659,14 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPTJ::setThreadCount(int32_t n_threads) {
|
||||||
|
d_ptr->n_threads = n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t GPTJ::threadCount() {
|
||||||
|
return d_ptr->n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
GPTJ::~GPTJ()
|
GPTJ::~GPTJ()
|
||||||
{
|
{
|
||||||
ggml_free(d_ptr->model.ctx);
|
ggml_free(d_ptr->model.ctx);
|
||||||
|
4
gptj.h
4
gptj.h
@ -17,9 +17,11 @@ public:
|
|||||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||||
float temp = 0.0f, int32_t n_batch = 9) override;
|
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||||
|
void setThreadCount(int32_t n_threads) override;
|
||||||
|
int32_t threadCount() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GPTJPrivate *d_ptr;
|
GPTJPrivate *d_ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // GPTJ_H
|
#endif // GPTJ_H
|
||||||
|
22
llm.cpp
22
llm.cpp
@ -62,6 +62,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
|||||||
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
||||||
m_llmodel->loadModel(modelName.toStdString(), fin);
|
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
|
emit threadCountChanged();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_llmodel)
|
if (m_llmodel)
|
||||||
@ -70,6 +71,15 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
|||||||
return m_llmodel;
|
return m_llmodel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LLMObject::setThreadCount(int32_t n_threads) {
|
||||||
|
m_llmodel->setThreadCount(n_threads);
|
||||||
|
emit threadCountChanged();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t LLMObject::threadCount() {
|
||||||
|
return m_llmodel->threadCount();
|
||||||
|
}
|
||||||
|
|
||||||
bool LLMObject::isModelLoaded() const
|
bool LLMObject::isModelLoaded() const
|
||||||
{
|
{
|
||||||
return m_llmodel && m_llmodel->isModelLoaded();
|
return m_llmodel && m_llmodel->isModelLoaded();
|
||||||
@ -225,6 +235,9 @@ LLM::LLM()
|
|||||||
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
|
||||||
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
|
||||||
|
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
|
||||||
|
|
||||||
|
|
||||||
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
|
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
|
||||||
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
|
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
|
||||||
|
|
||||||
@ -233,6 +246,7 @@ LLM::LLM()
|
|||||||
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
|
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
|
||||||
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
|
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
|
||||||
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
|
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
|
||||||
|
connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LLM::isModelLoaded() const
|
bool LLM::isModelLoaded() const
|
||||||
@ -300,6 +314,14 @@ QList<QString> LLM::modelList() const
|
|||||||
return m_llmodel->modelList();
|
return m_llmodel->modelList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LLM::setThreadCount(int32_t n_threads) {
|
||||||
|
emit setThreadCountRequested(n_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t LLM::threadCount() {
|
||||||
|
return m_llmodel->threadCount();
|
||||||
|
}
|
||||||
|
|
||||||
bool LLM::checkForUpdates() const
|
bool LLM::checkForUpdates() const
|
||||||
{
|
{
|
||||||
#if defined(Q_OS_LINUX)
|
#if defined(Q_OS_LINUX)
|
||||||
|
10
llm.h
10
llm.h
@ -12,6 +12,8 @@ class LLMObject : public QObject
|
|||||||
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
||||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||||
|
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
|
||||||
|
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@ -22,6 +24,8 @@ public:
|
|||||||
void resetResponse();
|
void resetResponse();
|
||||||
void resetContext();
|
void resetContext();
|
||||||
void stopGenerating() { m_stopGenerating = true; }
|
void stopGenerating() { m_stopGenerating = true; }
|
||||||
|
void setThreadCount(int32_t n_threads);
|
||||||
|
int32_t threadCount();
|
||||||
|
|
||||||
QString response() const;
|
QString response() const;
|
||||||
QString modelName() const;
|
QString modelName() const;
|
||||||
@ -42,6 +46,7 @@ Q_SIGNALS:
|
|||||||
void responseStopped();
|
void responseStopped();
|
||||||
void modelNameChanged();
|
void modelNameChanged();
|
||||||
void modelListChanged();
|
void modelListChanged();
|
||||||
|
void threadCountChanged();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool loadModelPrivate(const QString &modelName);
|
bool loadModelPrivate(const QString &modelName);
|
||||||
@ -65,6 +70,7 @@ class LLM : public QObject
|
|||||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||||
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
||||||
|
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||||
public:
|
public:
|
||||||
|
|
||||||
static LLM *globalInstance();
|
static LLM *globalInstance();
|
||||||
@ -76,6 +82,8 @@ public:
|
|||||||
Q_INVOKABLE void resetResponse();
|
Q_INVOKABLE void resetResponse();
|
||||||
Q_INVOKABLE void resetContext();
|
Q_INVOKABLE void resetContext();
|
||||||
Q_INVOKABLE void stopGenerating();
|
Q_INVOKABLE void stopGenerating();
|
||||||
|
Q_INVOKABLE void setThreadCount(int32_t n_threads);
|
||||||
|
Q_INVOKABLE int32_t threadCount();
|
||||||
|
|
||||||
QString response() const;
|
QString response() const;
|
||||||
bool responseInProgress() const { return m_responseInProgress; }
|
bool responseInProgress() const { return m_responseInProgress; }
|
||||||
@ -99,6 +107,8 @@ Q_SIGNALS:
|
|||||||
void modelNameChangeRequested(const QString &modelName);
|
void modelNameChangeRequested(const QString &modelName);
|
||||||
void modelNameChanged();
|
void modelNameChanged();
|
||||||
void modelListChanged();
|
void modelListChanged();
|
||||||
|
void threadCountChanged();
|
||||||
|
void setThreadCountRequested(int32_t threadCount);
|
||||||
|
|
||||||
private Q_SLOTS:
|
private Q_SLOTS:
|
||||||
void responseStarted();
|
void responseStarted();
|
||||||
|
@ -19,6 +19,8 @@ public:
|
|||||||
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||||
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||||
|
virtual void setThreadCount(int32_t n_threads);
|
||||||
|
virtual int32_t threadCount();
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // LLMODEL_H
|
#endif // LLMODEL_H
|
||||||
|
38
main.qml
38
main.qml
@ -107,7 +107,6 @@ Window {
|
|||||||
property int defaultTopK: 40
|
property int defaultTopK: 40
|
||||||
property int defaultMaxLength: 4096
|
property int defaultMaxLength: 4096
|
||||||
property int defaultPromptBatchSize: 9
|
property int defaultPromptBatchSize: 9
|
||||||
|
|
||||||
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
|
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
|
||||||
### Prompt:
|
### Prompt:
|
||||||
%1
|
%1
|
||||||
@ -141,7 +140,7 @@ Window {
|
|||||||
|
|
||||||
GridLayout {
|
GridLayout {
|
||||||
columns: 2
|
columns: 2
|
||||||
rowSpacing: 10
|
rowSpacing: 2
|
||||||
columnSpacing: 10
|
columnSpacing: 10
|
||||||
anchors.fill: parent
|
anchors.fill: parent
|
||||||
|
|
||||||
@ -278,14 +277,41 @@ Window {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Label {
|
Label {
|
||||||
id: promptTemplateLabel
|
id: nThreadsLabel
|
||||||
text: qsTr("Prompt Template:")
|
text: qsTr("CPU Threads")
|
||||||
Layout.row: 5
|
Layout.row: 5
|
||||||
Layout.column: 0
|
Layout.column: 0
|
||||||
}
|
}
|
||||||
Rectangle {
|
TextField {
|
||||||
|
text: LLM.threadCount.toString()
|
||||||
|
ToolTip.text: qsTr("Amount of processing threads to use")
|
||||||
|
ToolTip.visible: hovered
|
||||||
Layout.row: 5
|
Layout.row: 5
|
||||||
Layout.column: 1
|
Layout.column: 1
|
||||||
|
validator: IntValidator { bottom: 1 }
|
||||||
|
onAccepted: {
|
||||||
|
var val = parseInt(text)
|
||||||
|
if (!isNaN(val)) {
|
||||||
|
LLM.threadCount = val
|
||||||
|
focus = false
|
||||||
|
} else {
|
||||||
|
text = settingsDialog.nThreads.toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Accessible.role: Accessible.EditableText
|
||||||
|
Accessible.name: nThreadsLabel.text
|
||||||
|
Accessible.description: ToolTip.text
|
||||||
|
}
|
||||||
|
|
||||||
|
Label {
|
||||||
|
id: promptTemplateLabel
|
||||||
|
text: qsTr("Prompt Template:")
|
||||||
|
Layout.row: 6
|
||||||
|
Layout.column: 0
|
||||||
|
}
|
||||||
|
Rectangle {
|
||||||
|
Layout.row: 6
|
||||||
|
Layout.column: 1
|
||||||
Layout.fillWidth: true
|
Layout.fillWidth: true
|
||||||
height: 200
|
height: 200
|
||||||
color: "transparent"
|
color: "transparent"
|
||||||
@ -319,7 +345,7 @@ Window {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Button {
|
Button {
|
||||||
Layout.row: 6
|
Layout.row: 7
|
||||||
Layout.column: 1
|
Layout.column: 1
|
||||||
Layout.fillWidth: true
|
Layout.fillWidth: true
|
||||||
padding: 15
|
padding: 15
|
||||||
|
Loading…
Reference in New Issue
Block a user