From 31fffd3fc5f96513115b1a0cb08461c1a23b3c25 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Fri, 26 Aug 2022 15:47:08 +0800 Subject: [PATCH] [fx] fix wrong variable name in solver rotor (#1502) * [fx] fix wrong variable name in solver rotor * [fx] fix wrong variable name in solver rotor * code modification --- .../fx/passes/algorithms/ckpt_solver_rotor.py | 6 +++--- colossalai/fx/passes/algorithms/utils.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index 396cf7b29..f7de4987c 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -73,7 +73,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple: return (opt, what) -def _rec(chain, lmin, lmax, cmem, opt_table): +def _rec(chain: Chain, lmin, lmax, cmem, opt_table): """ chain : the class describing the AC graph lmin : index of the first forward to execute lmax : upper bound index of the last forward to execute (not included) @@ -97,14 +97,14 @@ def _rec(chain, lmin, lmax, cmem, opt_table): if what[cmem][lmin][lmax][0]: sequence.insert(ForwardEnable(lmin)) - sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table)) + sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table)) sequence.insert(Backward(lmin)) else: j = what[cmem][lmin][lmax][1] sequence.insert(ForwardCheck(lmin)) for k in range(lmin + 1, j): sequence.insert(ForwardNograd(k)) - sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table)) + sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table)) sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) return sequence diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py index 88efe0a0c..d26f1a2e2 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/utils.py @@ -44,9 +44,9 @@ class Forward(Operation): def __repr__(self): return "{n}_{i}".format(n=self.name, i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return chain.fweigth[self.index] + return chain.fweight[self.index] else: return 1 @@ -80,9 +80,9 @@ class Forwards(Operation): def __repr__(self): return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return sum(chain.fweigth[self.index[0]:self.index[1] + 1]) + return sum(chain.fweight[self.index[0]:self.index[1] + 1]) else: return (self.index[1] - self.index[0] + 1) @@ -99,9 +99,9 @@ class Backward(Operation): def __repr__(self): return "B_{i}".format(i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): if chain is not None: - return chain.bweigth[self.index] + return chain.bweight[self.index] else: return 1 @@ -126,7 +126,7 @@ class MemoryAccess(Operation): def __repr__(self): return "{n}_{i}".format(n=self.name, i=self.index) - def cost(self, chain): + def cost(self, chain: Chain): return 0