add min_p sampling parameter (#2014)

Signed-off-by: Christopher Barrera <cb@arda.tx.rr.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
This commit is contained in:
chrisbarrera
2024-02-24 16:51:34 -06:00
committed by GitHub
parent a153cc5b25
commit f8b1069a1c
28 changed files with 176 additions and 14 deletions

View File

@@ -568,16 +568,17 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, temp, n_batch,
return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
repeat_penalty, repeat_penalty_tokens);
}
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
if (!isModelLoaded())
@@ -608,6 +609,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
@@ -1020,6 +1022,7 @@ void ChatLLM::processSystemPrompt()
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@@ -1028,6 +1031,7 @@ void ChatLLM::processSystemPrompt()
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
@@ -1067,6 +1071,7 @@ void ChatLLM::processRestoreStateFromText()
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@@ -1075,6 +1080,7 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;

View File

@@ -139,7 +139,7 @@ Q_SIGNALS:
protected:
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);

View File

@@ -1380,6 +1380,7 @@ Window {
MySettings.maxLength,
MySettings.topK,
MySettings.topP,
MySettings.minP,
MySettings.temperature,
MySettings.promptBatchSize,
MySettings.repeatPenalty,

View File

@@ -60,12 +60,23 @@ double ModelInfo::topP() const
return MySettings::globalInstance()->modelTopP(*this);
}
double ModelInfo::minP() const
{
return MySettings::globalInstance()->modelMinP(*this);
}
void ModelInfo::setTopP(double p)
{
if (isClone) MySettings::globalInstance()->setModelTopP(*this, p, isClone /*force*/);
m_topP = p;
}
void ModelInfo::setMinP(double p)
{
if (isClone) MySettings::globalInstance()->setModelMinP(*this, p, isClone /*force*/);
m_minP = p;
}
int ModelInfo::topK() const
{
return MySettings::globalInstance()->modelTopK(*this);
@@ -321,6 +332,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
@@ -571,6 +583,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->temperature();
case TopPRole:
return info->topP();
case MinPRole:
return info->minP();
case TopKRole:
return info->topK();
case MaxLengthRole:
@@ -700,6 +714,8 @@ void ModelList::updateData(const QString &id, int role, const QVariant &value)
info->setTemperature(value.toDouble()); break;
case TopPRole:
info->setTopP(value.toDouble()); break;
case MinPRole:
info->setMinP(value.toDouble()); break;
case TopKRole:
info->setTopK(value.toInt()); break;
case MaxLengthRole:
@@ -797,6 +813,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::OnlineRole, model.isOnline);
updateData(id, ModelList::TemperatureRole, model.temperature());
updateData(id, ModelList::TopPRole, model.topP());
updateData(id, ModelList::MinPRole, model.minP());
updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
@@ -1163,6 +1180,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::TemperatureRole, obj["temperature"].toDouble());
if (obj.contains("topP"))
updateData(id, ModelList::TopPRole, obj["topP"].toDouble());
if (obj.contains("minP"))
updateData(id, ModelList::MinPRole, obj["minP"].toDouble());
if (obj.contains("topK"))
updateData(id, ModelList::TopKRole, obj["topK"].toInt());
if (obj.contains("maxLength"))
@@ -1287,6 +1306,8 @@ void ModelList::updateModelsFromSettings()
const double temperature = settings.value(g + "/temperature").toDouble();
Q_ASSERT(settings.contains(g + "/topP"));
const double topP = settings.value(g + "/topP").toDouble();
Q_ASSERT(settings.contains(g + "/minP"));
const double minP = settings.value(g + "/minP").toDouble();
Q_ASSERT(settings.contains(g + "/topK"));
const int topK = settings.value(g + "/topK").toInt();
Q_ASSERT(settings.contains(g + "/maxLength"));
@@ -1312,6 +1333,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::FilenameRole, filename);
updateData(id, ModelList::TemperatureRole, temperature);
updateData(id, ModelList::TopPRole, topP);
updateData(id, ModelList::MinPRole, minP);
updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);

View File

@@ -36,6 +36,7 @@ struct ModelInfo {
Q_PROPERTY(bool isClone MEMBER isClone)
Q_PROPERTY(double temperature READ temperature WRITE setTemperature)
Q_PROPERTY(double topP READ topP WRITE setTopP)
Q_PROPERTY(double minP READ minP WRITE setMinP)
Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
@@ -92,6 +93,8 @@ public:
void setTemperature(double t);
double topP() const;
void setTopP(double p);
double minP() const;
void setMinP(double p);
int topK() const;
void setTopK(int k);
int maxLength() const;
@@ -119,6 +122,7 @@ private:
QString m_filename;
double m_temperature = 0.7;
double m_topP = 0.4;
double m_minP = 0.0;
int m_topK = 40;
int m_maxLength = 4096;
int m_promptBatchSize = 128;
@@ -247,6 +251,7 @@ public:
RepeatPenaltyTokensRole,
PromptTemplateRole,
SystemPromptRole,
MinPRole,
};
QHash<int, QByteArray> roleNames() const override
@@ -282,6 +287,7 @@ public:
roles[IsCloneRole] = "isClone";
roles[TemperatureRole] = "temperature";
roles[TopPRole] = "topP";
roles[MinPRole] = "minP";
roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize";

View File

@@ -87,6 +87,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
{
setModelTemperature(model, model.m_temperature);
setModelTopP(model, model.m_topP);
setModelMinP(model, model.m_minP);
setModelTopK(model, model.m_topK);;
setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize);
@@ -201,6 +202,13 @@ double MySettings::modelTopP(const ModelInfo &m) const
return setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble();
}
double MySettings::modelMinP(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/minP", m.m_minP).toDouble();
}
void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
{
if (modelTopP(m) == p && !force)
@@ -216,6 +224,21 @@ void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
emit topPChanged(m);
}
void MySettings::setModelMinP(const ModelInfo &m, double p, bool force)
{
if (modelMinP(m) == p && !force)
return;
QSettings setting;
if (m.m_minP == p && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/minP");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/minP", p);
setting.sync();
if (!force)
emit minPChanged(m);
}
int MySettings::modelTopK(const ModelInfo &m) const
{
QSettings setting;

View File

@@ -47,6 +47,8 @@ public:
Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false);
double modelTopP(const ModelInfo &m) const;
Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false);
double modelMinP(const ModelInfo &m) const;
Q_INVOKABLE void setModelMinP(const ModelInfo &m, double p, bool force = false);
int modelTopK(const ModelInfo &m) const;
Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false);
int modelMaxLength(const ModelInfo &m) const;
@@ -119,6 +121,7 @@ Q_SIGNALS:
void filenameChanged(const ModelInfo &model);
void temperatureChanged(const ModelInfo &model);
void topPChanged(const ModelInfo &model);
void minPChanged(const ModelInfo &model);
void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model);

View File

@@ -452,6 +452,50 @@ MySettingsTab {
Accessible.name: topPLabel.text
Accessible.description: ToolTip.text
}
MySettingsLabel {
id: minPLabel
text: qsTr("Min P")
Layout.row: 3
Layout.column: 0
}
MyTextField {
id: minPField
text: root.currentModelInfo.minP
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Sets the minimum relative probability for a token to be considered.")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
Connections {
target: MySettings
function onMinPChanged() {
minPField.text = root.currentModelInfo.minP;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
minPField.text = root.currentModelInfo.minP;
}
}
onEditingFinished: {
var val = parseFloat(text)
if (!isNaN(val)) {
MySettings.setModelMinP(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.minP
}
}
Accessible.role: Accessible.EditableText
Accessible.name: minPLabel.text
Accessible.description: ToolTip.text
}
MySettingsLabel {
id: topKLabel
visible: !root.currentModelInfo.isOnline
@@ -592,8 +636,8 @@ MySettingsTab {
id: repeatPenaltyLabel
visible: !root.currentModelInfo.isOnline
text: qsTr("Repeat Penalty")
Layout.row: 3
Layout.column: 0
Layout.row: 4
Layout.column: 2
}
MyTextField {
id: repeatPenaltyField
@@ -603,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
Layout.row: 4
Layout.column: 3
validator: DoubleValidator {
locale: "C"
}

View File

@@ -205,6 +205,10 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
if (body.contains("top_p"))
top_p = body["top_p"].toDouble();
float min_p = 0.f;
if (body.contains("min_p"))
min_p = body["min_p"].toDouble();
int n = 1;
if (body.contains("n"))
n = body["n"].toInt();
@@ -312,6 +316,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
max_tokens /*n_predict*/,
top_k,
top_p,
min_p,
temperature,
n_batch,
repeat_penalty,