From 9c66308922757bd9c36719e250df6d4d33875c35 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 8 May 2023 18:55:33 -0400 Subject: [PATCH] Fix for special im_end token in mpt-7b-chat model. --- llmodel/mpt.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llmodel/mpt.cpp b/llmodel/mpt.cpp index 6e4c5761..d1fb0164 100644 --- a/llmodel/mpt.cpp +++ b/llmodel/mpt.cpp @@ -959,6 +959,7 @@ struct MPTPrivate { int64_t n_threads = 0; size_t mem_per_token = 0; std::mt19937 rng; + bool has_im_end = false; }; MPT::MPT() @@ -982,6 +983,7 @@ bool MPT::loadModel(const std::string &modelPath) { d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); d_ptr->modelLoaded = true; + d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end(); fflush(stdout); return true; } @@ -1150,6 +1152,10 @@ void MPT::prompt(const std::string &prompt, // display text ++totalPredictions; + // mpt-7b-chat has special token for end + if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"]) + goto stop_generating; + if (id == 0 /*end of text*/) goto stop_generating;