[autoparallel] fix C version rotor inconsistency (#1691)

This commit is contained in:
Boyuan Yao
2022-10-12 15:21:58 +08:00
committed by GitHub
parent 363fc2861a
commit 31d2f03d27
3 changed files with 54 additions and 22 deletions

View File

@@ -10,6 +10,9 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.logging import get_dist_logger
# global variable to indicate whether the solver has failed
SOLVER_FAILED = False
# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
@@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
opt, what = opt_table
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
if opt[cmem][lmin][lmax] == float("inf"):
raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
lmax=lmax,
cmem=cmem))
# use the logger to announce that the solver has failed
logger = get_dist_logger()
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
lmax=lmax,
cmem=cmem))
# set global indicator SOLVER_FAILED to True
global SOLVER_FAILED
SOLVER_FAILED = True
return sequence
if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
@@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule,
# found sequence
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_list)
# if solver failed, we don't need to annotate the graph
if not SOLVER_FAILED:
_annotate_from_sequence(sequence, node_list)
# set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence)
if SOLVER_FAILED:
setattr(gm, "__sequence__", None)
else:
setattr(gm, "__sequence__", sequence)
# set __opttable__ attribute to GraphModule
setattr(gm, "__opttable__", opt_table[0])
gm.recompile()
return gm

View File

@@ -94,13 +94,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
OPT(m, i, i) = INFINITY;
for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
long maxCostFWD = 0;
for (long l = i + 1; l <= chain_length; ++l) {
long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
if (l > i + 1) {
maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
for (long d = 1; d <= chain_length; ++d) {
for (long i = 0; i <= chain_length - d; ++i) {
long idx = i + d;
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
if (idx > i + 1) {
long maxCostFWD = 0;
for (long j = i + 1; j < idx; j++) {
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
}
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
}
if ((m >= mmin)) {
long bestLeaf = -1;
@@ -108,10 +111,10 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
double bestLeafCost = INFINITY;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for (long j = i + 1; j <= l; ++j) {
for (long j = i + 1; j <= idx; ++j) {
sumFw += fw[j - 1];
if (m >= cw[j]) {
double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
if (cost < bestLeafCost) {
bestLeafCost = cost;
bestLeaf = j;
@@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
}
double chainCost = INFINITY;
if (m >= cbw[i + 1])
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
if (bestLeafCost <= chainCost) {
OPT(m, i, l) = bestLeafCost;
WHAT(m, i, l) = bestLeaf;
OPT(m, i, idx) = bestLeafCost;
WHAT(m, i, idx) = bestLeaf;
} else {
OPT(m, i, l) = chainCost;
WHAT(m, i, l) = -1;
OPT(m, i, idx) = chainCost;
WHAT(m, i, idx) = -1;
}
} else
OPT(m, i, l) = INFINITY;
OPT(m, i, idx) = INFINITY;
}
}