Implement configurable context length (#1749)

Jared Van Bortel
2023-12-16 17:58:15 -05:00
committed by GitHub
parent 7aa0f779de
commit d1c56b8b28
31 changed files with 291 additions and 135 deletions

View File

@@ -20,15 +20,17 @@ ChatGPT::ChatGPT()
{
}
size_t ChatGPT::requiredMem(const std::string &modelPath)
size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return 0;
}
bool ChatGPT::loadModel(const std::string &modelPath)
bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return true;
}
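
Note: the ChatGPT backend takes the new n_ctx parameter only to keep the LLModel interface uniform; the remote API manages its own context window, so both overrides simply mark the argument unused.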

View File

@@ -48,9 +48,9 @@ public:
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override;
bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
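
For reference, these overrides track the virtual interface declared in llmodel.h, which is among the 31 changed files but not shown in this excerpt. A minimal sketch of the presumed new signatures; the default arguments are assumptions inferred from the call sites later in this diff:

#include <cstddef>
#include <string>

// Sketch only -- not the actual llmodel.h. The signatures follow the overrides
// above; the default arguments for construct() are assumptions.
class LLModel {
public:
    class Implementation {
    public:
        // n_ctx is threaded through construction so the backend can size its
        // context window (and KV cache) before the weights are loaded.
        static LLModel *construct(const std::string &modelPath,
                                  const std::string &buildVariant = "auto",
                                  int n_ctx = 2048);
    };

    virtual ~LLModel() = default;
    virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
    virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
};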

View File

@@ -5,7 +5,7 @@
#include <QDataStream>
#define CHAT_FORMAT_MAGIC 0xF5D553CC
#define CHAT_FORMAT_VERSION 6
#define CHAT_FORMAT_VERSION 7
class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)

View File

@@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = model;
} else {
// TODO: make configurable in UI
auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "metal");
else
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
buildVariant = "metal";
#endif
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) {
// Update the settings that a model is being loaded and update the device list
@@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (requestedDevice == "CPU") {
emit reportFallbackReason(""); // fallback not applicable
} else {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
LLModel::GPUDevice *device = nullptr;
@@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Report which device we're actually using
emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
// llama_init_from_file returned nullptr
emit reportDevice("CPU");
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
success = m_llModelInfo.model->loadModel(filePath.toStdString());
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
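
Taken together, the chatllm.cpp hunks thread the per-model setting through every stage of model construction. A condensed sketch of the flow, with device selection, signals and error reporting omitted; the helper name is hypothetical and the types follow the members used above:

// Hypothetical condensation of ChatLLM::loadModel() above; assumes the project
// headers (mysettings.h, modellist.h, llmodel.h) and is not a drop-in replacement.
static LLModel *loadWithConfiguredContext(const ModelInfo &modelInfo,
                                          const std::string &path,
                                          LLModel::PromptContext &ctx,
                                          bool forceMetal)
{
    const int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
    ctx.n_ctx = n_ctx;                       // prompt context sized to the setting

    std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__)
    if (forceMetal)
        buildVariant = "metal";
#endif
    LLModel *model = LLModel::Implementation::construct(path, buildVariant, n_ctx);

    // VRAM requirements now depend on the requested context length as well.
    const size_t requiredMemory = model->requiredMem(path, n_ctx);
    (void)requiredMemory;                    // used above for GPU device selection

    if (!model->loadModel(path, n_ctx)) {
        // GPU load failed (out of VRAM?) -- fall back to the CPU and retry.
        model->loadModel(path, n_ctx);
    }
    return model;
}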
@@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
return false;
}
// this function serializes the cached model state to disk.
// we also want to serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
@@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 6) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());
@@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;
if (version >= 6) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
}
quint64 logitsSize;
stream >> logitsSize;
if (!discardKV) {
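
For clarity, here are the write and read halves of the new field side by side; the guards are the same ones used in the hunks above, so chats saved with an older format version simply skip it:

// serialize(): only written for version >= 6, so older readers never see it.
if (version >= 6) {
    stream << m_ctx.n_ctx;
}

// deserialize(): mirror the guard and honour discardKV like the surrounding fields.
if (version >= 6) {
    uint32_t n_ctx;
    stream >> n_ctx;
    if (!discardKV) m_ctx.n_ctx = n_ctx;
}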

View File

@@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel()
return false;
}
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
bool success = m_model->loadModel(filePath.toStdString());
m_model = LLModel::Implementation::construct(filePath.toStdString());
bool success = m_model->loadModel(filePath.toStdString(), 2048);
if (!success) {
qWarning() << "WARNING: Could not load sbert";
delete m_model;
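
The embedding model is left out of the user-facing setting: it is now constructed with the default build variant and loaded with a fixed 2048-token context, which is presumably ample for the short passages the SBert model embeds.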

View File

@@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s)
m_promptBatchSize = s;
}
int ModelInfo::contextLength() const
{
return MySettings::globalInstance()->modelContextLength(*this);
}
void ModelInfo::setContextLength(int l)
{
if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/);
m_contextLength = l;
}
double ModelInfo::repeatPenalty() const
{
return MySettings::globalInstance()->modelRepeatPenalty(*this);
@@ -274,6 +285,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
@@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->maxLength();
case PromptBatchSizeRole:
return info->promptBatchSize();
case ContextLengthRole:
return info->contextLength();
case RepeatPenaltyRole:
return info->repeatPenalty();
case RepeatPenaltyTokensRole:
@@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
updateData(id, ModelList::ContextLengthRole, model.contextLength());
updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty());
updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens());
updateData(id, ModelList::PromptTemplateRole, model.promptTemplate());
@@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt());
if (obj.contains("promptBatchSize"))
updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt());
if (obj.contains("contextLength"))
updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt());
if (obj.contains("repeatPenalty"))
updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble());
if (obj.contains("repeatPenaltyTokens"))
@@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings()
const int maxLength = settings.value(g + "/maxLength").toInt();
Q_ASSERT(settings.contains(g + "/promptBatchSize"));
const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt();
Q_ASSERT(settings.contains(g + "/contextLength"));
const int contextLength = settings.value(g + "/contextLength").toInt();
Q_ASSERT(settings.contains(g + "/repeatPenalty"));
const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble();
Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens"));
@@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
updateData(id, ModelList::ContextLengthRole, contextLength);
updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty);
updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens);
updateData(id, ModelList::PromptTemplateRole, promptTemplate);
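
With the role, property, clone path, models.json key and settings restore all wired up, the value is readable wherever the other per-model sampling settings are. A hypothetical consumer, not part of this diff:

// Hypothetical usage; modelList and row are stand-ins for real values.
// ModelList is a QAbstractListModel, so the new role is reachable via data().
QModelIndex idx = modelList->index(row, 0);
int n_ctx = modelList->data(idx, ModelList::ContextLengthRole).toInt();
// ModelInfo::contextLength() resolves through MySettings::modelContextLength(),
// so clones and per-model overrides are picked up automatically.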

View File

@@ -39,6 +39,7 @@ struct ModelInfo {
Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
@@ -94,6 +95,8 @@ public:
void setMaxLength(int l);
int promptBatchSize() const;
void setPromptBatchSize(int s);
int contextLength() const;
void setContextLength(int l);
double repeatPenalty() const;
void setRepeatPenalty(double p);
int repeatPenaltyTokens() const;
@@ -112,6 +115,7 @@ private:
int m_topK = 40;
int m_maxLength = 4096;
int m_promptBatchSize = 128;
int m_contextLength = 2048;
double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n";
@@ -227,6 +231,7 @@ public:
TopKRole,
MaxLengthRole,
PromptBatchSizeRole,
ContextLengthRole,
RepeatPenaltyRole,
RepeatPenaltyTokensRole,
PromptTemplateRole,
@@ -269,6 +274,7 @@ public:
roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize";
roles[ContextLengthRole] = "contextLength";
roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate";

View File

@@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
setModelTopK(model, model.m_topK);
setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize);
setModelContextLength(model, model.m_contextLength);
setModelRepeatPenalty(model, model.m_repeatPenalty);
setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens);
setModelPromptTemplate(model, model.m_promptTemplate);
@@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force)
emit promptBatchSizeChanged(m);
}
int MySettings::modelContextLength(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt();
}
void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force)
{
if (modelContextLength(m) == l && !force)
return;
QSettings setting;
if (m.m_contextLength == l && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/contextLength");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l);
setting.sync();
if (!force)
emit contextLengthChanged(m);
}
double MySettings::modelRepeatPenalty(const ModelInfo &m) const
{
QSettings setting;
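
The getter falls back to the model's built-in default, and the setter deletes the key again when a non-clone is set back to that default, so only genuine overrides are persisted. A short hypothetical usage sketch under those semantics:

// Hypothetical usage; 'model' is some ModelInfo instance.
MySettings *s = MySettings::globalInstance();

int n_ctx = s->modelContextLength(model);  // stored override if present,
                                           // otherwise the model's default

s->setModelContextLength(model, 4096);     // persists "model-<id>/contextLength"
                                           // and emits contextLengthChanged(model)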

View File

@@ -1,6 +1,8 @@
#ifndef MYSETTINGS_H
#define MYSETTINGS_H
#include <cstdint>
#include <QObject>
#include <QMutex>
@@ -59,6 +61,8 @@ public:
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false);
QString modelSystemPrompt(const ModelInfo &m) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false);
int modelContextLength(const ModelInfo &m) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false);
// Application settings
int threadCount() const;
@@ -79,6 +83,8 @@ public:
void setForceMetal(bool b);
QString device() const;
void setDevice(const QString &u);
int32_t contextLength() const;
void setContextLength(int32_t value);
// Release/Download settings
QString lastVersionStarted() const;
@@ -114,6 +120,7 @@ Q_SIGNALS:
void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model);
void contextLengthChanged(const ModelInfo &model);
void repeatPenaltyChanged(const ModelInfo &model);
void repeatPenaltyTokensChanged(const ModelInfo &model);
void promptTemplateChanged(const ModelInfo &model);

View File

@@ -349,13 +349,61 @@ MySettingsTab {
rowSpacing: 10
columnSpacing: 10
Label {
id: contextLengthLabel
visible: !root.currentModelInfo.isChatGPT
text: qsTr("Context Length:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 0
Layout.column: 0
}
MyTextField {
id: contextLengthField
visible: !root.currentModelInfo.isChatGPT
text: root.currentModelInfo.contextLength
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
validator: IntValidator {
bottom: 1
}
Connections {
target: MySettings
function onContextLengthChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
onEditingFinished: {
var val = parseInt(text)
if (!isNaN(val)) {
MySettings.setModelContextLength(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.contextLength
}
}
Accessible.role: Accessible.EditableText
Accessible.name: contextLengthLabel.text
Accessible.description: ToolTip.text
}
Label {
id: tempLabel
text: qsTr("Temperature:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 0
Layout.row: 1
Layout.column: 2
}
MyTextField {
@@ -365,8 +413,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
Layout.row: 1
Layout.column: 3
validator: DoubleValidator {
locale: "C"
}
@@ -400,8 +448,8 @@ MySettingsTab {
text: qsTr("Top P:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 0
Layout.column: 2
Layout.row: 2
Layout.column: 0
}
MyTextField {
id: topPField
@@ -410,8 +458,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 3
Layout.row: 2
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
@@ -446,8 +494,8 @@ MySettingsTab {
text: qsTr("Top K:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.column: 0
Layout.row: 2
Layout.column: 2
}
MyTextField {
id: topKField
@@ -457,8 +505,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from")
ToolTip.visible: hovered
Layout.row: 1
Layout.column: 1
Layout.row: 2
Layout.column: 3
validator: IntValidator {
bottom: 1
}
@@ -493,7 +541,7 @@ MySettingsTab {
text: qsTr("Max Length:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 1
Layout.row: 0
Layout.column: 2
}
MyTextField {
@@ -504,7 +552,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum length of response in tokens")
ToolTip.visible: hovered
Layout.row: 1
Layout.row: 0
Layout.column: 3
validator: IntValidator {
bottom: 1
@@ -541,7 +589,7 @@ MySettingsTab {
text: qsTr("Prompt Batch Size:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 2
Layout.row: 1
Layout.column: 0
}
MyTextField {
@@ -552,7 +600,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM")
ToolTip.visible: hovered
Layout.row: 2
Layout.row: 1
Layout.column: 1
validator: IntValidator {
bottom: 1
@@ -588,8 +636,8 @@ MySettingsTab {
text: qsTr("Repeat Penalty:")
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 2
Layout.column: 2
Layout.row: 3
Layout.column: 0
}
MyTextField {
id: repeatPenaltyField
@@ -599,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered
Layout.row: 2
Layout.column: 3
Layout.row: 3
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
@@ -636,7 +684,7 @@ MySettingsTab {
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
Layout.row: 3
Layout.column: 0
Layout.column: 2
}
MyTextField {
id: repeatPenaltyTokenField
@@ -647,7 +695,7 @@ MySettingsTab {
ToolTip.text: qsTr("How far back in output to apply repeat penalty")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
Layout.column: 3
validator: IntValidator {
bottom: 1
}
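
The remaining hunks in this file only renumber Layout.row and Layout.column on the existing Temperature, Top P, Top K, Max Length, Prompt Batch Size and Repeat Penalty fields so that the settings grid has room for the new Context Length controls in its first row.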