mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-02 00:57:09 +00:00
add min_p sampling parameter (#2014)
Signed-off-by: Christopher Barrera <cb@arda.tx.rr.com> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
This commit is contained in:
@@ -568,16 +568,17 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
|
||||
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
|
||||
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
|
||||
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
|
||||
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
|
||||
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
|
||||
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
|
||||
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
|
||||
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
|
||||
return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, temp, n_batch,
|
||||
return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
|
||||
repeat_penalty, repeat_penalty_tokens);
|
||||
}
|
||||
|
||||
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens)
|
||||
{
|
||||
if (!isModelLoaded())
|
||||
@@ -608,6 +609,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
m_ctx.n_predict = n_predict;
|
||||
m_ctx.top_k = top_k;
|
||||
m_ctx.top_p = top_p;
|
||||
m_ctx.min_p = min_p;
|
||||
m_ctx.temp = temp;
|
||||
m_ctx.n_batch = n_batch;
|
||||
m_ctx.repeat_penalty = repeat_penalty;
|
||||
@@ -1020,6 +1022,7 @@ void ChatLLM::processSystemPrompt()
|
||||
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
|
||||
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
|
||||
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
|
||||
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
|
||||
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
|
||||
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
|
||||
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
|
||||
@@ -1028,6 +1031,7 @@ void ChatLLM::processSystemPrompt()
|
||||
m_ctx.n_predict = n_predict;
|
||||
m_ctx.top_k = top_k;
|
||||
m_ctx.top_p = top_p;
|
||||
m_ctx.min_p = min_p;
|
||||
m_ctx.temp = temp;
|
||||
m_ctx.n_batch = n_batch;
|
||||
m_ctx.repeat_penalty = repeat_penalty;
|
||||
@@ -1067,6 +1071,7 @@ void ChatLLM::processRestoreStateFromText()
|
||||
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
|
||||
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
|
||||
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
|
||||
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
|
||||
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
|
||||
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
|
||||
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
|
||||
@@ -1075,6 +1080,7 @@ void ChatLLM::processRestoreStateFromText()
|
||||
m_ctx.n_predict = n_predict;
|
||||
m_ctx.top_k = top_k;
|
||||
m_ctx.top_p = top_p;
|
||||
m_ctx.min_p = min_p;
|
||||
m_ctx.temp = temp;
|
||||
m_ctx.n_batch = n_batch;
|
||||
m_ctx.repeat_penalty = repeat_penalty;
|
||||
|
@@ -139,7 +139,7 @@ Q_SIGNALS:
|
||||
|
||||
protected:
|
||||
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens);
|
||||
bool handlePrompt(int32_t token);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
|
@@ -1380,6 +1380,7 @@ Window {
|
||||
MySettings.maxLength,
|
||||
MySettings.topK,
|
||||
MySettings.topP,
|
||||
MySettings.minP,
|
||||
MySettings.temperature,
|
||||
MySettings.promptBatchSize,
|
||||
MySettings.repeatPenalty,
|
||||
|
@@ -60,12 +60,23 @@ double ModelInfo::topP() const
|
||||
return MySettings::globalInstance()->modelTopP(*this);
|
||||
}
|
||||
|
||||
double ModelInfo::minP() const
|
||||
{
|
||||
return MySettings::globalInstance()->modelMinP(*this);
|
||||
}
|
||||
|
||||
void ModelInfo::setTopP(double p)
|
||||
{
|
||||
if (isClone) MySettings::globalInstance()->setModelTopP(*this, p, isClone /*force*/);
|
||||
m_topP = p;
|
||||
}
|
||||
|
||||
void ModelInfo::setMinP(double p)
|
||||
{
|
||||
if (isClone) MySettings::globalInstance()->setModelMinP(*this, p, isClone /*force*/);
|
||||
m_minP = p;
|
||||
}
|
||||
|
||||
int ModelInfo::topK() const
|
||||
{
|
||||
return MySettings::globalInstance()->modelTopK(*this);
|
||||
@@ -321,6 +332,7 @@ ModelList::ModelList()
|
||||
connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
|
||||
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
|
||||
@@ -571,6 +583,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
|
||||
return info->temperature();
|
||||
case TopPRole:
|
||||
return info->topP();
|
||||
case MinPRole:
|
||||
return info->minP();
|
||||
case TopKRole:
|
||||
return info->topK();
|
||||
case MaxLengthRole:
|
||||
@@ -700,6 +714,8 @@ void ModelList::updateData(const QString &id, int role, const QVariant &value)
|
||||
info->setTemperature(value.toDouble()); break;
|
||||
case TopPRole:
|
||||
info->setTopP(value.toDouble()); break;
|
||||
case MinPRole:
|
||||
info->setMinP(value.toDouble()); break;
|
||||
case TopKRole:
|
||||
info->setTopK(value.toInt()); break;
|
||||
case MaxLengthRole:
|
||||
@@ -797,6 +813,7 @@ QString ModelList::clone(const ModelInfo &model)
|
||||
updateData(id, ModelList::OnlineRole, model.isOnline);
|
||||
updateData(id, ModelList::TemperatureRole, model.temperature());
|
||||
updateData(id, ModelList::TopPRole, model.topP());
|
||||
updateData(id, ModelList::MinPRole, model.minP());
|
||||
updateData(id, ModelList::TopKRole, model.topK());
|
||||
updateData(id, ModelList::MaxLengthRole, model.maxLength());
|
||||
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
|
||||
@@ -1163,6 +1180,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
|
||||
updateData(id, ModelList::TemperatureRole, obj["temperature"].toDouble());
|
||||
if (obj.contains("topP"))
|
||||
updateData(id, ModelList::TopPRole, obj["topP"].toDouble());
|
||||
if (obj.contains("minP"))
|
||||
updateData(id, ModelList::MinPRole, obj["minP"].toDouble());
|
||||
if (obj.contains("topK"))
|
||||
updateData(id, ModelList::TopKRole, obj["topK"].toInt());
|
||||
if (obj.contains("maxLength"))
|
||||
@@ -1287,6 +1306,8 @@ void ModelList::updateModelsFromSettings()
|
||||
const double temperature = settings.value(g + "/temperature").toDouble();
|
||||
Q_ASSERT(settings.contains(g + "/topP"));
|
||||
const double topP = settings.value(g + "/topP").toDouble();
|
||||
Q_ASSERT(settings.contains(g + "/minP"));
|
||||
const double minP = settings.value(g + "/minP").toDouble();
|
||||
Q_ASSERT(settings.contains(g + "/topK"));
|
||||
const int topK = settings.value(g + "/topK").toInt();
|
||||
Q_ASSERT(settings.contains(g + "/maxLength"));
|
||||
@@ -1312,6 +1333,7 @@ void ModelList::updateModelsFromSettings()
|
||||
updateData(id, ModelList::FilenameRole, filename);
|
||||
updateData(id, ModelList::TemperatureRole, temperature);
|
||||
updateData(id, ModelList::TopPRole, topP);
|
||||
updateData(id, ModelList::MinPRole, minP);
|
||||
updateData(id, ModelList::TopKRole, topK);
|
||||
updateData(id, ModelList::MaxLengthRole, maxLength);
|
||||
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
|
||||
|
@@ -36,6 +36,7 @@ struct ModelInfo {
|
||||
Q_PROPERTY(bool isClone MEMBER isClone)
|
||||
Q_PROPERTY(double temperature READ temperature WRITE setTemperature)
|
||||
Q_PROPERTY(double topP READ topP WRITE setTopP)
|
||||
Q_PROPERTY(double minP READ minP WRITE setMinP)
|
||||
Q_PROPERTY(int topK READ topK WRITE setTopK)
|
||||
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
|
||||
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
|
||||
@@ -92,6 +93,8 @@ public:
|
||||
void setTemperature(double t);
|
||||
double topP() const;
|
||||
void setTopP(double p);
|
||||
double minP() const;
|
||||
void setMinP(double p);
|
||||
int topK() const;
|
||||
void setTopK(int k);
|
||||
int maxLength() const;
|
||||
@@ -119,6 +122,7 @@ private:
|
||||
QString m_filename;
|
||||
double m_temperature = 0.7;
|
||||
double m_topP = 0.4;
|
||||
double m_minP = 0.0;
|
||||
int m_topK = 40;
|
||||
int m_maxLength = 4096;
|
||||
int m_promptBatchSize = 128;
|
||||
@@ -247,6 +251,7 @@ public:
|
||||
RepeatPenaltyTokensRole,
|
||||
PromptTemplateRole,
|
||||
SystemPromptRole,
|
||||
MinPRole,
|
||||
};
|
||||
|
||||
QHash<int, QByteArray> roleNames() const override
|
||||
@@ -282,6 +287,7 @@ public:
|
||||
roles[IsCloneRole] = "isClone";
|
||||
roles[TemperatureRole] = "temperature";
|
||||
roles[TopPRole] = "topP";
|
||||
roles[MinPRole] = "minP";
|
||||
roles[TopKRole] = "topK";
|
||||
roles[MaxLengthRole] = "maxLength";
|
||||
roles[PromptBatchSizeRole] = "promptBatchSize";
|
||||
|
@@ -87,6 +87,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
|
||||
{
|
||||
setModelTemperature(model, model.m_temperature);
|
||||
setModelTopP(model, model.m_topP);
|
||||
setModelMinP(model, model.m_minP);
|
||||
setModelTopK(model, model.m_topK);;
|
||||
setModelMaxLength(model, model.m_maxLength);
|
||||
setModelPromptBatchSize(model, model.m_promptBatchSize);
|
||||
@@ -201,6 +202,13 @@ double MySettings::modelTopP(const ModelInfo &m) const
|
||||
return setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble();
|
||||
}
|
||||
|
||||
double MySettings::modelMinP(const ModelInfo &m) const
|
||||
{
|
||||
QSettings setting;
|
||||
setting.sync();
|
||||
return setting.value(QString("model-%1").arg(m.id()) + "/minP", m.m_minP).toDouble();
|
||||
}
|
||||
|
||||
void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
|
||||
{
|
||||
if (modelTopP(m) == p && !force)
|
||||
@@ -216,6 +224,21 @@ void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
|
||||
emit topPChanged(m);
|
||||
}
|
||||
|
||||
void MySettings::setModelMinP(const ModelInfo &m, double p, bool force)
|
||||
{
|
||||
if (modelMinP(m) == p && !force)
|
||||
return;
|
||||
|
||||
QSettings setting;
|
||||
if (m.m_minP == p && !m.isClone)
|
||||
setting.remove(QString("model-%1").arg(m.id()) + "/minP");
|
||||
else
|
||||
setting.setValue(QString("model-%1").arg(m.id()) + "/minP", p);
|
||||
setting.sync();
|
||||
if (!force)
|
||||
emit minPChanged(m);
|
||||
}
|
||||
|
||||
int MySettings::modelTopK(const ModelInfo &m) const
|
||||
{
|
||||
QSettings setting;
|
||||
|
@@ -47,6 +47,8 @@ public:
|
||||
Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false);
|
||||
double modelTopP(const ModelInfo &m) const;
|
||||
Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false);
|
||||
double modelMinP(const ModelInfo &m) const;
|
||||
Q_INVOKABLE void setModelMinP(const ModelInfo &m, double p, bool force = false);
|
||||
int modelTopK(const ModelInfo &m) const;
|
||||
Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false);
|
||||
int modelMaxLength(const ModelInfo &m) const;
|
||||
@@ -119,6 +121,7 @@ Q_SIGNALS:
|
||||
void filenameChanged(const ModelInfo &model);
|
||||
void temperatureChanged(const ModelInfo &model);
|
||||
void topPChanged(const ModelInfo &model);
|
||||
void minPChanged(const ModelInfo &model);
|
||||
void topKChanged(const ModelInfo &model);
|
||||
void maxLengthChanged(const ModelInfo &model);
|
||||
void promptBatchSizeChanged(const ModelInfo &model);
|
||||
|
@@ -452,6 +452,50 @@ MySettingsTab {
|
||||
Accessible.name: topPLabel.text
|
||||
Accessible.description: ToolTip.text
|
||||
}
|
||||
MySettingsLabel {
|
||||
id: minPLabel
|
||||
text: qsTr("Min P")
|
||||
Layout.row: 3
|
||||
Layout.column: 0
|
||||
}
|
||||
MyTextField {
|
||||
id: minPField
|
||||
text: root.currentModelInfo.minP
|
||||
color: theme.textColor
|
||||
font.pixelSize: theme.fontSizeLarge
|
||||
ToolTip.text: qsTr("Sets the minimum relative probability for a token to be considered.")
|
||||
ToolTip.visible: hovered
|
||||
Layout.row: 3
|
||||
Layout.column: 1
|
||||
validator: DoubleValidator {
|
||||
locale: "C"
|
||||
}
|
||||
Connections {
|
||||
target: MySettings
|
||||
function onMinPChanged() {
|
||||
minPField.text = root.currentModelInfo.minP;
|
||||
}
|
||||
}
|
||||
Connections {
|
||||
target: root
|
||||
function onCurrentModelInfoChanged() {
|
||||
minPField.text = root.currentModelInfo.minP;
|
||||
}
|
||||
}
|
||||
onEditingFinished: {
|
||||
var val = parseFloat(text)
|
||||
if (!isNaN(val)) {
|
||||
MySettings.setModelMinP(root.currentModelInfo, val)
|
||||
focus = false
|
||||
} else {
|
||||
text = root.currentModelInfo.minP
|
||||
}
|
||||
}
|
||||
Accessible.role: Accessible.EditableText
|
||||
Accessible.name: minPLabel.text
|
||||
Accessible.description: ToolTip.text
|
||||
}
|
||||
|
||||
MySettingsLabel {
|
||||
id: topKLabel
|
||||
visible: !root.currentModelInfo.isOnline
|
||||
@@ -592,8 +636,8 @@ MySettingsTab {
|
||||
id: repeatPenaltyLabel
|
||||
visible: !root.currentModelInfo.isOnline
|
||||
text: qsTr("Repeat Penalty")
|
||||
Layout.row: 3
|
||||
Layout.column: 0
|
||||
Layout.row: 4
|
||||
Layout.column: 2
|
||||
}
|
||||
MyTextField {
|
||||
id: repeatPenaltyField
|
||||
@@ -603,8 +647,8 @@ MySettingsTab {
|
||||
font.pixelSize: theme.fontSizeLarge
|
||||
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
|
||||
ToolTip.visible: hovered
|
||||
Layout.row: 3
|
||||
Layout.column: 1
|
||||
Layout.row: 4
|
||||
Layout.column: 3
|
||||
validator: DoubleValidator {
|
||||
locale: "C"
|
||||
}
|
||||
|
@@ -205,6 +205,10 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
||||
if (body.contains("top_p"))
|
||||
top_p = body["top_p"].toDouble();
|
||||
|
||||
float min_p = 0.f;
|
||||
if (body.contains("min_p"))
|
||||
min_p = body["min_p"].toDouble();
|
||||
|
||||
int n = 1;
|
||||
if (body.contains("n"))
|
||||
n = body["n"].toInt();
|
||||
@@ -312,6 +316,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
||||
max_tokens /*n_predict*/,
|
||||
top_k,
|
||||
top_p,
|
||||
min_p,
|
||||
temperature,
|
||||
n_batch,
|
||||
repeat_penalty,
|
||||
|
Reference in New Issue
Block a user