diff --git a/llmodel/mpt.cpp b/llmodel/mpt.cpp
index 6e4c5761..d1fb0164 100644
--- a/llmodel/mpt.cpp
+++ b/llmodel/mpt.cpp
@@ -959,6 +959,7 @@ struct MPTPrivate {
     int64_t n_threads = 0;
     size_t mem_per_token = 0;
     std::mt19937 rng;
+    bool has_im_end = false;
 };
 
 MPT::MPT()
@@ -982,6 +983,7 @@ bool MPT::loadModel(const std::string &modelPath) {
 
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
+    d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
     fflush(stdout);
     return true;
 }
@@ -1150,6 +1152,10 @@ void MPT::prompt(const std::string &prompt,
         // display text
         ++totalPredictions;
 
+        // mpt-7b-chat has a special token for end of generation
+        if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"])
+            goto stop_generating;
+
         if (id == 0 /*end of text*/)
             goto stop_generating;
 
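For context, here is a minimal standalone sketch of the pattern the patch uses: probe the vocabulary once at load time for the chat end-of-turn token, cache the result, and compare each sampled id against it during generation. The token_to_id name and the "<|im_end|>" literal mirror the diff; the toy vocabulary, token ids, and sampled stream below are invented purely for illustration.

// Standalone sketch of the stop-token check added in the diff above.
// Only token_to_id and "<|im_end|>" come from mpt.cpp; the toy data
// and the fake "generation" loop are hypothetical.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    // Toy vocabulary; a real model loads this from the model file.
    std::unordered_map<std::string, int> token_to_id = {
        {"<|im_start|>", 50277},
        {"<|im_end|>",   50278},
    };

    // Probe once at load time, as loadModel() does in the patch,
    // instead of searching the map on every sampled token.
    const auto it = token_to_id.find("<|im_end|>");
    const bool has_im_end = it != token_to_id.end();
    const int im_end_id = has_im_end ? it->second : -1;

    // Pretend token stream; 50278 stands in for <|im_end|>, 0 for end-of-text.
    const std::vector<int> sampled = {11, 42, 50278, 7, 0};

    for (const int id : sampled) {
        // mpt-7b-chat signals end of turn with <|im_end|>; plain mpt-7b has
        // no such token, so only the generic id == 0 check would fire.
        if (has_im_end && id == im_end_id)
            break;
        if (id == 0) // end of text
            break;
        std::printf("emit token %d\n", id);
    }
    return 0;
}

Caching im_end_id up front also avoids the per-token map lookup that token_to_id["<|im_end|>"] performs inside the diff's generation loop; that lookup is safe there because the has_im_end guard short-circuits it, but operator[] on the map would otherwise insert a new entry if the key were missing.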