From 8593ae1a3fada934da66cf680b93c73cff718139 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Sat, 18 Feb 2023 11:30:15 +0800 Subject: [PATCH] [autoparallel] rotor solver refactor (#2813) * [autoparallel] rotor solver refactor * [autoparallel] rotor solver refactor --- .../checkpoint/ckpt_solver_rotor.c | 28 +++++++++++++------ .../checkpoint/ckpt_solver_rotor.py | 16 +++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c index 0fdcfd58a..8dad074bc 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c @@ -1,6 +1,12 @@ #define PY_SSIZE_T_CLEAN #include <Python.h> +/* +Rotor solver for checkpointing problem in C. We follow the modeling mentioned in +paper `Optimal checkpointing for heterogeneous chains: how to train deep neural +networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of +the code are adapted from https://gitlab.inria.fr/hiepacs/rotor. 
+*/ long* PySequenceToLongArray(PyObject* pylist) { if (!(pylist && PySequence_Check(pylist))) return NULL; Py_ssize_t len = PySequence_Size(pylist); @@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) { (mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long)); for (long m = 0; m <= mmax; ++m) - for (long i = 0; i <= chainLength; ++i) + for (long i = 0; i <= chainLength; ++i) { if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) && - (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) + (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) { COST_TABLE(m, i, i) = ftime[i] + btime[i]; - else + } else { COST_TABLE(m, i, i) = INFINITY; + } + } - for (long m = 0; m <= mmax; ++m) + for (long m = 0; m <= mmax; ++m) { for (long d = 1; d <= chainLength; ++d) { for (long i = 0; i <= chainLength - d; ++i) { long idx = i + d; @@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) { } } double chainCost = INFINITY; - if (m >= xbar[i + 1]) + if (m >= xbar[i + 1]) { chainCost = COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx); + } if (bestLeafCost <= chainCost) { COST_TABLE(m, i, idx) = bestLeafCost; BACK_PTR(m, i, idx) = bestLeaf; @@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) { COST_TABLE(m, i, idx) = chainCost; BACK_PTR(m, i, idx) = -1; } - } else + } else { COST_TABLE(m, i, idx) = INFINITY; + } } } + } free(ftime); free(btime); @@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) { PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l); Py_DECREF(pyCostTable_m_i_l); PyObject* pyBackPtr_m_i_l; - if (BACK_PTR(m, i, l) < 0) + if (BACK_PTR(m, i, l) < 0) { pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True); - else + } else { pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l)); + } PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l); Py_DECREF(pyBackPtr_m_i_l); Py_DECREF(pyVar_l); diff --git 
a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 41d23be5c..21c3bf0da 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -207,11 +207,10 @@ class CheckpointSolverRotor(CheckpointSolverBase): mmax (int): Maximum number of memory slots. Returns: - cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length - and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax - back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice - is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint - of length j + cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs + with m memory slots. + back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice + is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j """ ftime = chain.ftime + [0.0] @@ -224,18 +223,17 @@ class CheckpointSolverRotor(CheckpointSolverBase): # Build table cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)] - # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation - # Initialize borders of the tables for lmax-lmin = 0 + # Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs for m in range(mmax + 1): for i in range(len(chain) + 1): limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i]) - if m >= limit: # Equation (1) + if m >= limit: cost_table[m][i][i] = ftime[i] + btime[i] else: cost_table[m][i][i] = float("inf") - # Compute everything + # Compute tables for m in range(mmax + 1): for d in range(1, len(chain) + 1): for i in range(len(chain) + 1 - d):