From 97ec9074e52a06c9bb08338653c7c7b89ff92d69 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Wed, 3 May 2023 11:58:26 -0400
Subject: [PATCH] Add reverse prompts for llama models.

---
 llmodel/llamamodel.cpp | 43 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 39 insertions(+), 4 deletions(-)

diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp
index 0da930b6..d89862f1 100644
--- a/llmodel/llamamodel.cpp
+++ b/llmodel/llamamodel.cpp
@@ -16,6 +16,7 @@
 #include <unistd.h>
 #include <random>
 #include <thread>
+#include <unordered_set>
 
 struct LLamaPrivate {
     const std::string modelPath;
@@ -144,6 +145,11 @@ void LLamaModel::prompt(const std::string &prompt,
         i = batch_end;
     }
 
+    std::string cachedResponse;
+    std::vector<llama_token> cachedTokens;
+    std::unordered_set<std::string> reversePrompts
+        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
@@ -175,11 +181,40 @@
 
         if (id == llama_token_eos())
             return;
 
-        if (promptCtx.tokens.size() == promptCtx.n_ctx)
-            promptCtx.tokens.erase(promptCtx.tokens.begin());
-        promptCtx.tokens.push_back(id);
-        if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+        const std::string str = llama_token_to_str(d_ptr->ctx, id);
+
+        // Check if the provided str is part of our reverse prompts
+        bool foundPartialReversePrompt = false;
+        const std::string completed = cachedResponse + str;
+        if (reversePrompts.find(completed) != reversePrompts.end()) {
             return;
+        }
+
+        // Check if it partially matches our reverse prompts and if so, cache
+        for (auto s : reversePrompts) {
+            if (s.compare(0, completed.size(), completed) == 0) {
+                foundPartialReversePrompt = true;
+                cachedResponse = completed;
+                break;
+            }
+        }
+
+        // Regardless, the token gets added to our cache
+        cachedTokens.push_back(id);
+
+        // Continue if we have found a partial match
+        if (foundPartialReversePrompt)
+            continue;
+
+        // Empty the cache
+        for (auto t : cachedTokens) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(t);
+            if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
+                return;
+        }
+        cachedTokens.clear();
     }
 }
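
Note: the hunk above implements a small token-buffering state machine. Each
decoded token string is appended to cachedResponse; if the accumulated text
fully matches a reverse prompt, generation stops and the buffered tokens are
never surfaced; if it is a prefix of one, the tokens are held back; otherwise
the buffer is flushed through responseCallback. What follows is a minimal,
self-contained sketch of the same technique for review purposes. The names
(ReversePromptFilter, feed) are hypothetical, not part of the gpt4all API,
and plain strings stand in for llama tokens. One deliberate difference: the
sketch clears the cached prefix when the buffer is flushed, whereas the patch
leaves cachedResponse set, so a stale partial match could be prepended to
later tokens.

#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical helper for illustration; not part of the gpt4all API.
// Buffers decoded token strings until it is clear whether they begin a
// reverse prompt: tokens are dropped on a full match, held on a partial
// match, and flushed to the caller otherwise.
class ReversePromptFilter {
public:
    explicit ReversePromptFilter(std::unordered_set<std::string> reversePrompts)
        : m_reversePrompts(std::move(reversePrompts)) {}

    // Feed one decoded token. Returns false when a reverse prompt has been
    // fully matched and generation should stop; otherwise appends any tokens
    // that are now safe to surface onto `out` and returns true.
    bool feed(const std::string &str, std::vector<std::string> &out) {
        const std::string completed = m_cachedResponse + str;

        // Full match: discard the buffered tokens and signal a stop.
        if (m_reversePrompts.count(completed) != 0)
            return false;

        // Partial match: `completed` is a strict prefix of some reverse
        // prompt, so keep buffering and wait for more tokens.
        for (const auto &s : m_reversePrompts) {
            if (s.compare(0, completed.size(), completed) == 0) {
                m_cachedResponse = completed;
                m_cachedTokens.push_back(str);
                return true;
            }
        }

        // No match: flush the buffer, including the current token. Unlike
        // the patch above, the cached prefix is cleared here as well, so a
        // stale partial match cannot leak into later comparisons.
        m_cachedTokens.push_back(str);
        out.insert(out.end(), m_cachedTokens.begin(), m_cachedTokens.end());
        m_cachedTokens.clear();
        m_cachedResponse.clear();
        return true;
    }

private:
    std::unordered_set<std::string> m_reversePrompts;
    std::string m_cachedResponse;
    std::vector<std::string> m_cachedTokens;
};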
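
A quick check of the control flow, compiled together with the sketch above
(expected output noted inline):

int main() {
    ReversePromptFilter filter({ "### Human", "### Assistant" });
    std::vector<std::string> emitted;

    // "###" and " Hum" are held back as partial matches of "### Human";
    // "an" completes the reverse prompt, so feed() returns false and the
    // buffered tokens are never emitted.
    for (const std::string &tok : { "Hello", " world", "###", " Hum", "an" }) {
        if (!filter.feed(tok, emitted))
            break;
    }

    for (const auto &t : emitted)
        std::cout << t;
    std::cout << '\n';   // prints "Hello world"
    return 0;
}

Flushing buffered tokens through the same path that emits unbuffered ones, as
the patch also does with responseCallback, keeps the surfaced text and the
recorded context tokens in the same order with no special cases.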