Complete revamp of model loading to allow the user more discrete control
over the model's loading behavior.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Adam Treat
2024-02-07 09:37:59 -05:00
committed by AT
parent f2024a1f9e
commit d948a4f2ee
14 changed files with 506 additions and 175 deletions

@@ -180,6 +180,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->model_params.use_mlock = params.use_mlock;
 #endif
+    d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
+    d_ptr->model_params.progress_callback_user_data = this;
 #ifdef GGML_USE_METAL
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
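
The two added lines hand llama.cpp a C-style progress hook. For reference, stock llama.h declares llama_progress_callback as bool (*)(float progress, void *user_data), with progress running from 0.0 to 1.0 and a false return asking the loader to abort. A minimal standalone sketch of the same wiring, assuming only the stock llama.h API and a hypothetical on_progress handler:

#include "llama.h"
#include <cstdio>

// Illustrative handler, not part of the patch: print progress and
// keep loading. Returning false would cancel the load mid-way.
static bool on_progress(float progress, void * /*user_data*/)
{
    std::fprintf(stderr, "loading: %3.0f%%\r", progress * 100.0f);
    return true;
}

static llama_model *load_with_progress(const char *path)
{
    llama_model_params params = llama_model_default_params();
    params.progress_callback = on_progress;
    params.progress_callback_user_data = nullptr;
    return llama_load_model_from_file(path, params);
}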

@@ -74,6 +74,8 @@ public:
         int32_t n_last_batch_tokens = 0;
     };
+
+    using ProgressCallback = std::function<bool(float progress)>;

     explicit LLModel() {}
     virtual ~LLModel() {}
@@ -125,6 +127,8 @@ public:
     virtual bool hasGPUDevice() { return false; }
     virtual bool usingGPUDevice() { return false; }
+
+    void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
 protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
@@ -153,6 +157,15 @@ protected:
     const Implementation *m_implementation = nullptr;
+
+    ProgressCallback m_progressCallback;
+    static bool staticProgressCallback(float progress, void* ctx)
+    {
+        LLModel* model = static_cast<LLModel*>(ctx);
+        if (model && model->m_progressCallback)
+            return model->m_progressCallback(progress);
+        return true;
+    }

 private:
     friend class LLMImplementation;
 };
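
Taken together, the header changes expose the hook as a typed std::function while staticProgressCallback keeps the C calling convention at the llama.cpp boundary. A caller-side sketch of the control this gives (model stands in for any concrete LLModel subclass; g_cancelRequested is a hypothetical flag, not part of the patch):

#include "llmodel.h"
#include <atomic>
#include <cstdio>
#include <string>

std::atomic<bool> g_cancelRequested{false}; // hypothetical cancel flag

void loadWithProgress(LLModel *model, const std::string &modelPath)
{
    // Report progress and allow the user to abort: returning false
    // from the callback cancels the in-flight load.
    model->setProgressCallback([](float progress) {
        std::fprintf(stderr, "model load: %3.0f%%\r", progress * 100.0f);
        return !g_cancelRequested.load();
    });
    model->loadModel(modelPath, /*n_ctx*/ 2048, /*ngl*/ 100);
}

Because staticProgressCallback returns true whenever no callback has been set, callers that never touch setProgressCallback keep the previous non-cancellable loading behavior.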