backend: do not crash if GGUF lacks general.architecture (#2346)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-05-15 13:57:13 -04:00 committed by GitHub
parent 6d8888b267
commit 9f9d8e636f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 24 deletions

View File

@ -786,12 +786,14 @@ const std::vector<LLModel::Token> &GPTJ::endTokens() const
} }
// Returns the string value stored under the general.architecture key of ctx_gguf.
// Throws std::runtime_error when gguf_find_key reports the key as missing (-1)
// or when the stored value is not of type GGUF_TYPE_STRING.
// NOTE(review): the result is handed back directly from gguf_get_val_str, so it
// presumably points into ctx_gguf; copy it before calling gguf_free -- confirm.
// NOTE(review): most lines of this hunk appear to hold the pre-change text
// followed by the post-change text of the same statement.
const char *get_arch_name(gguf_context *ctx_gguf) { const char *get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture"); const int kid = gguf_find_key(ctx_gguf, "general.architecture");
// A missing key is turned into an exception so that callers can catch it
// and fall back, rather than passing -1 on to the type lookup below.
if (kid == -1)
throw std::runtime_error("key not found in model: general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid); enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
// Only a string-typed value may be read with gguf_get_val_str.
if (ktype != GGUF_TYPE_STRING) { if (ktype != GGUF_TYPE_STRING)
throw std::runtime_error("ERROR: Can't get general architecture from gguf file."); throw std::runtime_error("key general.architecture has wrong type");
}
return gguf_get_val_str(ctx_gguf, kid); return gguf_get_val_str(ctx_gguf, kid);
} }
@ -824,7 +826,11 @@ DLL_EXPORT char *get_file_arch(const char *fname) {
char *arch = nullptr; char *arch = nullptr;
if (ctx_gguf && gguf_get_version(ctx_gguf) <= 3) { if (ctx_gguf && gguf_get_version(ctx_gguf) <= 3) {
arch = strdup(get_arch_name(ctx_gguf)); try {
arch = strdup(get_arch_name(ctx_gguf));
} catch (const std::runtime_error &) {
// cannot read key -> return null
}
} }
gguf_free(ctx_gguf); gguf_free(ctx_gguf);

View File

@ -105,12 +105,14 @@ static int llama_sample_top_p_top_k(
} }
// Looks up general.architecture in ctx_gguf and returns its string value.
// Raises std::runtime_error in two cases: the key index is -1 (not found),
// or the key exists but its type differs from GGUF_TYPE_STRING.
// NOTE(review): the pointer is returned as-is from gguf_get_val_str; the callers
// visible in this diff copy it into a std::string before gguf_free, which
// suggests it does not outlive the context -- confirm.
const char *get_arch_name(gguf_context *ctx_gguf) { const char *get_arch_name(gguf_context *ctx_gguf) {
std::string arch_name;
const int kid = gguf_find_key(ctx_gguf, "general.architecture"); const int kid = gguf_find_key(ctx_gguf, "general.architecture");
// Reject a missing key up front instead of querying the type of index -1.
if (kid == -1)
throw std::runtime_error("key not found in model: general.architecture");
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid); enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
// The value is read as a string below, so any other stored type is an error.
if (ktype != (GGUF_TYPE_STRING)) { if (ktype != GGUF_TYPE_STRING)
throw std::runtime_error("ERROR: Can't get general architecture from gguf file."); throw std::runtime_error("key general.architecture has wrong type");
}
return gguf_get_val_str(ctx_gguf, kid); return gguf_get_val_str(ctx_gguf, kid);
} }
@ -136,13 +138,20 @@ static gguf_context *load_gguf(const char *fname) {
} }
// Opens the GGUF file at modelPath and searches it for the key named
// <architecture>.<archKey>, where <architecture> comes from get_arch_name.
// Returns value, which starts out as -1 and is still -1 when the file cannot
// be loaded or the architecture key cannot be read.
// NOTE(review): the statements that assign value after a successful key
// lookup fall in the part of the body elided by the hunk boundary below, so
// the exact read is not visible here.
static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) { static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
// Declared ahead of the first goto, presumably so that jumping to cleanup
// never skips over an initialization.
int32_t value = -1;
std::string arch;
auto * ctx = load_gguf(modelPath.c_str()); auto * ctx = load_gguf(modelPath.c_str());
// NOTE(review): on this path cleanup runs gguf_free with a null ctx;
// assumes gguf_free accepts null -- confirm.
if (!ctx) if (!ctx)
return -1; goto cleanup;
std::string arch = get_arch_name(ctx);
// get_arch_name throws std::runtime_error for a missing or mistyped key;
// that is treated here as no value available rather than a fatal error.
int32_t value = -1; try {
if (ctx) { arch = get_arch_name(ctx);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
}
{
auto key = arch + "." + archKey; auto key = arch + "." + archKey;
int keyidx = gguf_find_key(ctx, key.c_str()); int keyidx = gguf_find_key(ctx, key.c_str());
if (keyidx != -1) { if (keyidx != -1) {
@ -152,6 +161,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
} }
} }
// Single exit point: every path releases the context before returning.
cleanup:
gguf_free(ctx); gguf_free(ctx);
return value; return value;
} }
@ -244,15 +254,26 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const {
} }
// Reports whether the GGUF file at modelPath has an architecture for which
// is_embedding_arch returns true.
// Returns false when the file cannot be loaded (a message is written to
// std::cerr in that case) or when get_arch_name throws std::runtime_error.
bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const { bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const {
// Declared before any goto, presumably so the jumps to cleanup do not
// bypass an initialization.
bool result = false;
std::string arch;
auto *ctx_gguf = load_gguf(modelPath.c_str()); auto *ctx_gguf = load_gguf(modelPath.c_str());
if (!ctx_gguf) { if (!ctx_gguf) {
std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n"; std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n";
// NOTE(review): cleanup then calls gguf_free on a null context;
// assumes gguf_free accepts null -- confirm.
return false; goto cleanup;
} }
// A missing or mistyped architecture key leaves result at false.
std::string arch = get_arch_name(ctx_gguf); try {
arch = get_arch_name(ctx_gguf);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
}
// arch is a std::string copy, so it remains usable after the context is freed.
result = is_embedding_arch(arch);
cleanup:
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
return is_embedding_arch(arch); return result;
} }
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl) bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
@ -964,16 +985,26 @@ DLL_EXPORT const char *get_build_variant() {
} }
// Exported entry point that returns the architecture name of the GGUF file
// fname as a strdup-allocated C string, or nullptr.
// nullptr is returned when the file cannot be loaded, when get_arch_name
// throws std::runtime_error, or when the architecture is an embedding
// architecture whose <arch>.pooling_type key is absent.
// NOTE(review): the non-null result comes from strdup, so the caller
// presumably has to release it with free -- confirm against callers.
DLL_EXPORT char *get_file_arch(const char *fname) { DLL_EXPORT char *get_file_arch(const char *fname) {
auto *ctx = load_gguf(fname);
char *arch = nullptr; char *arch = nullptr;
if (ctx) { std::string archStr;
std::string archStr = get_arch_name(ctx);
if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) { auto *ctx = load_gguf(fname);
// old bert.cpp embedding model if (!ctx)
} else { goto cleanup;
arch = strdup(archStr.c_str());
} try {
archStr = get_arch_name(ctx);
} catch (const std::runtime_error &) {
goto cleanup; // cannot read key
} }
// An embedding architecture without a pooling_type key is deliberately
// left as nullptr; the existing comment in the branch names the reason.
if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) {
// old bert.cpp embedding model
} else {
arch = strdup(archStr.c_str());
}
// NOTE(review): reached with ctx == nullptr when load_gguf fails; assumes
// gguf_free accepts null -- confirm.
cleanup:
gguf_free(ctx); gguf_free(ctx);
return arch; return arch;
} }