Backend prompt dedup (#822)

* Deduplicated prompt() function code
This commit is contained in:
AT
2023-06-04 08:59:24 -04:00
committed by GitHub
parent 945297d837
commit bbe195ee02
10 changed files with 286 additions and 457 deletions

View File

@@ -7,11 +7,14 @@
#include <string_view>
#include <fstream>
#include <cstdint>
#include <limits>
class Dlhandle;
class LLModel {
public:
using Token = int32_t;
class Implementation {
LLModel *(*construct_)();
@@ -60,11 +63,11 @@ public:
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
virtual void prompt(const std::string &prompt,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx);
virtual void setThreadCount(int32_t /*n_threads*/) {}
virtual int32_t threadCount() const { return 1; }
@@ -84,10 +87,20 @@ public:
}
protected:
const Implementation *m_implementation = nullptr;
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(const std::string&) const = 0;
virtual std::string_view tokenToString(Token) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token>& endTokens() const = 0;
// This is a helper function called from the default implementation of 'prompt' but it can be
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
static std::string m_implementations_search_path;
const Implementation *m_implementation = nullptr;
static std::string m_implementations_search_path;
};
#endif // LLMODEL_H