mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[autoparallel] fix C version rotor inconsistency (#1691)
This commit is contained in:
@@ -10,6 +10,9 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
# global vairable to indicate whether the solver is failed
|
||||
SOLVER_FAILED = False
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
# https://gitlab.inria.fr/hiepacs/rotor
|
||||
@@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
|
||||
opt, what = opt_table
|
||||
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
|
||||
if opt[cmem][lmin][lmax] == float("inf"):
|
||||
raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
||||
lmax=lmax,
|
||||
cmem=cmem))
|
||||
# using logger to annonce that the solver is failed
|
||||
logger = get_dist_logger()
|
||||
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
|
||||
lmax=lmax,
|
||||
cmem=cmem))
|
||||
|
||||
# set global indicater SOLVER_FAILED to True
|
||||
global SOLVER_FAILED
|
||||
SOLVER_FAILED = True
|
||||
return sequence
|
||||
|
||||
if lmin == lmax:
|
||||
if lmin == chain.length:
|
||||
sequence.insert(Loss())
|
||||
@@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule,
|
||||
|
||||
# found sequence
|
||||
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# if solver failed, we don't need to annotate the graph
|
||||
if not SOLVER_FAILED:
|
||||
_annotate_from_sequence(sequence, node_list)
|
||||
|
||||
# set __sequence__ attribute to GraphModule
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
if SOLVER_FAILED:
|
||||
setattr(gm, "__sequence__", None)
|
||||
else:
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
|
||||
# set __opttable__ attribute to GraphModule
|
||||
setattr(gm, "__opttable__", opt_table[0])
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
@@ -94,13 +94,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
||||
OPT(m, i, i) = INFINITY;
|
||||
|
||||
for (long m = 0; m <= mmax; ++m)
|
||||
for (long i = 0; i <= chain_length; ++i) {
|
||||
long maxCostFWD = 0;
|
||||
for (long l = i + 1; l <= chain_length; ++l) {
|
||||
long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (l > i + 1) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
|
||||
mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
|
||||
for (long d = 1; d <= chain_length; ++d) {
|
||||
for (long i = 0; i <= chain_length - d; ++i) {
|
||||
long idx = i + d;
|
||||
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
|
||||
if (idx > i + 1) {
|
||||
long maxCostFWD = 0;
|
||||
for (long j = i + 1; j < idx; j++) {
|
||||
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
|
||||
}
|
||||
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
|
||||
}
|
||||
if ((m >= mmin)) {
|
||||
long bestLeaf = -1;
|
||||
@@ -108,10 +111,10 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
||||
double bestLeafCost = INFINITY;
|
||||
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
|
||||
/// i+1
|
||||
for (long j = i + 1; j <= l; ++j) {
|
||||
for (long j = i + 1; j <= idx; ++j) {
|
||||
sumFw += fw[j - 1];
|
||||
if (m >= cw[j]) {
|
||||
double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
|
||||
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
|
||||
if (cost < bestLeafCost) {
|
||||
bestLeafCost = cost;
|
||||
bestLeaf = j;
|
||||
@@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
|
||||
}
|
||||
double chainCost = INFINITY;
|
||||
if (m >= cbw[i + 1])
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
|
||||
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
|
||||
if (bestLeafCost <= chainCost) {
|
||||
OPT(m, i, l) = bestLeafCost;
|
||||
WHAT(m, i, l) = bestLeaf;
|
||||
OPT(m, i, idx) = bestLeafCost;
|
||||
WHAT(m, i, idx) = bestLeaf;
|
||||
} else {
|
||||
OPT(m, i, l) = chainCost;
|
||||
WHAT(m, i, l) = -1;
|
||||
OPT(m, i, idx) = chainCost;
|
||||
WHAT(m, i, idx) = -1;
|
||||
}
|
||||
} else
|
||||
OPT(m, i, l) = INFINITY;
|
||||
OPT(m, i, idx) = INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user