Remove support for GPT-J models. (#2676)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT
2024-07-17 16:07:37 -04:00
committed by GitHub
parent e2ebd1ff04
commit ca72428783
14 changed files with 5 additions and 1238 deletions

View File

@@ -37,7 +37,6 @@ using namespace Qt::Literals::StringLiterals;
//#define DEBUG
//#define DEBUG_MODEL_LOADING
#define GPTJ_INTERNAL_STATE_VERSION 0
#define LLAMA_INTERNAL_STATE_VERSION 0
class LLModelStore {
@@ -550,7 +549,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
switch (m_llModelInfo.model->implementation().modelType()[0]) {
case 'L': m_llModelType = LLModelType::LLAMA_; break;
case 'G': m_llModelType = LLModelType::GPTJ_; break;
default:
{
m_llModelInfo.resetModel(this);
@@ -1057,7 +1055,6 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
if (version > 1) {
stream << m_llModelType;
switch (m_llModelType) {
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
default: Q_UNREACHABLE();
}
@@ -1081,8 +1078,6 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
@@ -1139,12 +1134,9 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
if (!discardKV) m_ctx.n_ctx = n_ctx;
}
quint64 logitsSize;
stream >> logitsSize;
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
if (version < 9) {
quint64 logitsSize;
stream >> logitsSize;
stream.skipRawData(logitsSize * sizeof(float));
}