Use the token cache to infer greater n_past and reuse results (#3073)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-10-31 11:19:12 -04:00
committed by GitHub
parent 62cab695eb
commit f07e2e63df
15 changed files with 320 additions and 169 deletions

View File

@@ -33,6 +33,7 @@
#include <functional>
#include <limits>
#include <optional>
#include <span>
#include <string_view>
#include <utility>
#include <vector>
@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
QString requestedDevice = MySettings::globalInstance()->device();
int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
std::string backend = "auto";
@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
else
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = m_trimmedResponse = std::string();
@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
if (version >= 7) {
stream << m_stateContextLength;
}
stream << quint64(m_stateInputTokens.size());
stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
if (version >= 7) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
if (!discardKV) m_stateContextLength = n_ctx;
}
if (version < 9) {
@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
quint64 tokensSize;
stream >> tokensSize;
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
m_stateInputTokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
} else {
stream.skipRawData(tokensSize * sizeof(int));
stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
}
if (version >= 1) {
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens);
if (!ok) {
// FIXME(jared): how badly does this situation break GPT4All?
qWarning() << "ChatLLM failed to save LLModel state";
m_state.clear();
m_state.squeeze();
m_stateContextLength = -1;
}
m_stateContextLength = m_llModelInfo.model->contextLength();
}
void ChatLLM::restoreState()
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
if (m_state.isEmpty())
return;
size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (bytesRead) {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
qWarning() << "restoring state from text because of n_ctx mismatch (state"
<< m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
m_restoreStateFromText = true;
} else {
size_t bytesRead = m_llModelInfo.model->restoreState(
{reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens
);
if (!bytesRead) {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
m_restoreStateFromText = true;
} else {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
}
}
// free local state copy unless unload is pending