Complete revamp of model loading to allow for more discreet control by

the user of the models loading behavior.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat
2024-02-07 09:37:59 -05:00
committed by AT
parent f2024a1f9e
commit d948a4f2ee
14 changed files with 506 additions and 175 deletions

View File

@@ -74,6 +74,8 @@ public:
int32_t n_last_batch_tokens = 0;
};
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {}
virtual ~LLModel() {}
@@ -125,6 +127,8 @@ public:
virtual bool hasGPUDevice() { return false; }
virtual bool usingGPUDevice() { return false; }
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
@@ -153,6 +157,15 @@ protected:
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
static bool staticProgressCallback(float progress, void* ctx)
{
LLModel* model = static_cast<LLModel*>(ctx);
if (model && model->m_progressCallback)
return model->m_progressCallback(progress);
return true;
}
private:
friend class LLMImplementation;
};