diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 20004aa4..0c4764b1 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -786,12 +786,14 @@ const std::vector<LLModel::Token> &GPTJ::endTokens() const
 }
 
 const char *get_arch_name(gguf_context *ctx_gguf) {
-    std::string arch_name;
     const int kid = gguf_find_key(ctx_gguf, "general.architecture");
+    if (kid == -1)
+        throw std::runtime_error("key not found in model: general.architecture");
+
     enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
-    if (ktype != GGUF_TYPE_STRING) {
-        throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
-    }
+    if (ktype != GGUF_TYPE_STRING)
+        throw std::runtime_error("key general.architecture has wrong type");
+
     return gguf_get_val_str(ctx_gguf, kid);
 }
 
@@ -824,7 +826,11 @@ DLL_EXPORT char *get_file_arch(const char *fname) {
 
     char *arch = nullptr;
     if (ctx_gguf && gguf_get_version(ctx_gguf) <= 3) {
-        arch = strdup(get_arch_name(ctx_gguf));
+        try {
+            arch = strdup(get_arch_name(ctx_gguf));
+        } catch (const std::runtime_error &) {
+            // cannot read key -> return null
+        }
     }
 
     gguf_free(ctx_gguf);
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 7c66be9f..250766bb 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -105,12 +105,14 @@ static int llama_sample_top_p_top_k(
 }
 
 const char *get_arch_name(gguf_context *ctx_gguf) {
-    std::string arch_name;
     const int kid = gguf_find_key(ctx_gguf, "general.architecture");
+    if (kid == -1)
+        throw std::runtime_error("key not found in model: general.architecture");
+
     enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
-    if (ktype != (GGUF_TYPE_STRING)) {
-        throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
-    }
+    if (ktype != GGUF_TYPE_STRING)
+        throw std::runtime_error("key general.architecture has wrong type");
+
     return gguf_get_val_str(ctx_gguf, kid);
 }
 
@@ -136,13 +138,20 @@ static gguf_context *load_gguf(const char *fname) {
 }
 
 static int32_t get_arch_key_u32(std::string const &modelPath, std::string const &archKey) {
+    int32_t value = -1;
+    std::string arch;
+
     auto * ctx = load_gguf(modelPath.c_str());
     if (!ctx)
-        return -1;
-    std::string arch = get_arch_name(ctx);
+        goto cleanup;
 
-    int32_t value = -1;
-    if (ctx) {
+    try {
+        arch = get_arch_name(ctx);
+    } catch (const std::runtime_error &) {
+        goto cleanup; // cannot read key
+    }
+
+    {
         auto key = arch + "." + archKey;
         int keyidx = gguf_find_key(ctx, key.c_str());
         if (keyidx != -1) {
@@ -152,6 +161,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
         }
     }
 
+cleanup:
     gguf_free(ctx);
     return value;
 }
@@ -244,15 +254,26 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const {
 }
 
 bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const {
+    bool result = false;
+    std::string arch;
+
     auto *ctx_gguf = load_gguf(modelPath.c_str());
     if (!ctx_gguf) {
         std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n";
-        return false;
+        goto cleanup;
     }
 
-    std::string arch = get_arch_name(ctx_gguf);
+    try {
+        arch = get_arch_name(ctx_gguf);
+    } catch (const std::runtime_error &) {
+        goto cleanup; // cannot read key
+    }
+
+    result = is_embedding_arch(arch);
+
+cleanup:
     gguf_free(ctx_gguf);
-    return is_embedding_arch(arch);
+    return result;
 }
 
 bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
@@ -964,16 +985,26 @@ DLL_EXPORT const char *get_build_variant() {
 }
 
 DLL_EXPORT char *get_file_arch(const char *fname) {
-    auto *ctx = load_gguf(fname);
     char *arch = nullptr;
-    if (ctx) {
-        std::string archStr = get_arch_name(ctx);
-        if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) {
-            // old bert.cpp embedding model
-        } else {
-            arch = strdup(archStr.c_str());
-        }
+    std::string archStr;
+
+    auto *ctx = load_gguf(fname);
+    if (!ctx)
+        goto cleanup;
+
+    try {
+        archStr = get_arch_name(ctx);
+    } catch (const std::runtime_error &) {
+        goto cleanup; // cannot read key
     }
+
+    if (is_embedding_arch(archStr) && gguf_find_key(ctx, (archStr + ".pooling_type").c_str()) < 0) {
+        // old bert.cpp embedding model
+    } else {
+        arch = strdup(archStr.c_str());
+    }
+
+cleanup:
     gguf_free(ctx);
     return arch;
 }
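
Caller-side note: with this patch, get_file_arch() returns nullptr instead of letting the std::runtime_error escape when a GGUF file has no readable general.architecture key. A minimal caller-side sketch follows; it is illustrative only and not part of the patch, and the detect_arch wrapper plus the bare extern "C" declaration are assumptions (the real exports use the backend's DLL_EXPORT macros).

// Illustrative only (not part of this patch): consuming the
// nullptr-on-failure contract of get_file_arch().
#include <cstdlib>   // std::free
#include <iostream>
#include <string>

// Sketch declaration; exact linkage/visibility macros are omitted here.
extern "C" char *get_file_arch(const char *fname);

std::string detect_arch(const std::string &path) {
    char *arch = get_file_arch(path.c_str());
    if (!arch) {
        // File unreadable, or general.architecture missing / wrongly typed.
        std::cerr << "could not determine model architecture for " << path << "\n";
        return {};
    }
    std::string result(arch);
    std::free(arch); // get_file_arch() allocates the string with strdup()
    return result;
}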