mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,14 +3,16 @@ import os
|
||||
from setuptools import Extension, setup
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ext_modules = [Extension(
|
||||
'rotorc',
|
||||
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
|
||||
)]
|
||||
ext_modules = [
|
||||
Extension(
|
||||
"rotorc",
|
||||
sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
|
||||
)
|
||||
]
|
||||
|
||||
setup(
|
||||
name='rotor c extension',
|
||||
version='0.1',
|
||||
description='rotor c extension for faster dp computing',
|
||||
name="rotor c extension",
|
||||
version="0.1",
|
||||
description="rotor c extension for faster dp computing",
|
||||
ext_modules=ext_modules,
|
||||
)
|
||||
|
@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
||||
)
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
|
||||
__all___ = ['CheckpointSolverBase']
|
||||
__all___ = ["CheckpointSolverBase"]
|
||||
|
||||
|
||||
def _copy_output(src: Graph, dst: Graph):
|
||||
"""Copy the output node from src to dst"""
|
||||
for n_src, n_dst in zip(src.nodes, dst.nodes):
|
||||
if n_src.op == 'output':
|
||||
if n_src.op == "output":
|
||||
n_dst.meta = n_src.meta
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
|
||||
|
||||
|
||||
class CheckpointSolverBase(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def solve(self):
|
||||
"""Solve the checkpointing problem and return the solution.
|
||||
"""
|
||||
pass
|
||||
"""Solve the checkpointing problem and return the solution."""
|
||||
|
||||
def get_node_list(self):
|
||||
"""Get the node list.
|
||||
"""
|
||||
"""Get the node list."""
|
||||
return [[node] for node in self.graph.nodes]
|
||||
|
||||
def _linearize_graph(self) -> List[List[Node]]:
|
||||
@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
|
||||
"""
|
||||
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from ``torch.fx.Node``
|
||||
"""
|
||||
"""Get the inplace argument from ``torch.fx.Node``"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
|
||||
return inplace
|
||||
|
||||
def _is_shape_consistency(n: Node):
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
|
||||
"""
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
|
||||
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
|
||||
map(_is_shape_consistency, n.users))
|
||||
return (
|
||||
not sum([v for _, v in deps.items()])
|
||||
and not any(map(_is_inplace, n.users))
|
||||
and not any(map(_is_shape_consistency, n.users))
|
||||
)
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if self.cnode:
|
||||
for name in self.cnode:
|
||||
try:
|
||||
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
|
||||
f"Common node {name} is not an input of the model."
|
||||
assert (
|
||||
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
|
||||
), f"Common node {name} is not an input of the model."
|
||||
except StopIteration:
|
||||
raise ValueError(f"Common node name {name} not in graph.")
|
||||
|
||||
@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
|
||||
region = []
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
|
||||
]) or _is_cop(n.target):
|
||||
if len(n.all_input_nodes) == len(
|
||||
[node for node in n.all_input_nodes if node.name in self.cnode]
|
||||
) or _is_cop(n.target):
|
||||
self.cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len([user for user in n.users if user.op != "output"])
|
||||
|
@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
|
||||
__all__ = ['CheckpointSolverChen']
|
||||
__all__ = ["CheckpointSolverChen"]
|
||||
|
||||
|
||||
class CheckpointSolverChen(CheckpointSolverBase):
|
||||
|
||||
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
|
||||
Returns:
|
||||
graph (Graph): The optimized graph, should be a copy of the original graph.
|
||||
"""
|
||||
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||
checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
|
||||
ckpt = self.grid_search()
|
||||
for i, seg in enumerate(ckpt):
|
||||
for idx in range(*seg):
|
||||
nodes = self.node_list[idx]
|
||||
for n in nodes:
|
||||
if n.op in checkpointable_op:
|
||||
n.meta['activation_checkpoint'] = i
|
||||
n.meta["activation_checkpoint"] = i
|
||||
return deepcopy(self.graph)
|
||||
|
||||
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
|
||||
|
@@ -1,5 +1,5 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.fx import Graph, Node
|
||||
@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
|
||||
|
||||
__all__ = ['CheckpointSolverRotor']
|
||||
__all__ = ["CheckpointSolverRotor"]
|
||||
|
||||
|
||||
class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
free_memory: float = -1,
|
||||
cnode: List[str] = None,
|
||||
memory_slots: int = 500,
|
||||
optim_multiplier: float = 1.0):
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
free_memory: float = -1,
|
||||
cnode: List[str] = None,
|
||||
memory_slots: int = 500,
|
||||
optim_multiplier: float = 1.0,
|
||||
):
|
||||
"""This is the simple implementation of dynamic programming algorithm rotor
|
||||
in https://hal.inria.fr/hal-02352969. Some code are adapted from
|
||||
https://gitlab.inria.fr/hiepacs/rotor.
|
||||
@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
|
||||
# backtrack
|
||||
try:
|
||||
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
|
||||
self.back_ptr)
|
||||
self.sequence = self._backtrack(
|
||||
chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
|
||||
)
|
||||
self._annotate_from_sequence(self.sequence, self.node_list)
|
||||
except ValueError as e:
|
||||
# using logger to annonce that the solver is failed
|
||||
logger = get_dist_logger()
|
||||
logger.warning(f'Checkpoint solver failed: {e}')
|
||||
logger.warning(f"Checkpoint solver failed: {e}")
|
||||
raise ValueError
|
||||
|
||||
if verbose:
|
||||
@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
return deepcopy(self.graph)
|
||||
|
||||
def print_chain(self):
|
||||
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
|
||||
print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
|
||||
for idx in range(len(self.node_list) - 1):
|
||||
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
|
||||
self.chain.btmp[idx])
|
||||
print(f'Chain = {self.chain}')
|
||||
print(
|
||||
self.node_list[idx],
|
||||
self.chain.x[idx + 1],
|
||||
self.chain.xbar[idx + 1],
|
||||
self.chain.ftmp[idx],
|
||||
self.chain.btmp[idx],
|
||||
)
|
||||
print(f"Chain = {self.chain}")
|
||||
|
||||
def print_sequence(self):
|
||||
print(f'Sequence = {self.sequence}')
|
||||
print(f"Sequence = {self.sequence}")
|
||||
|
||||
@classmethod
|
||||
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
|
||||
@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
btime = 0
|
||||
fwd_mem_peak = 0
|
||||
for n in node:
|
||||
assert isinstance(n, Node), f'{n} is not a Node'
|
||||
assert isinstance(n, Node), f"{n} is not a Node"
|
||||
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
|
||||
# in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
|
||||
xbar += n.meta['fwd_mem_out']
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
|
||||
xbar += n.meta["fwd_mem_out"]
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
|
||||
else:
|
||||
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
|
||||
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
|
||||
|
||||
# minimum flop count is required
|
||||
ftime += max(calculate_fwd_time(n), 1.0)
|
||||
@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
"""Extract input tensors from a Graph"""
|
||||
input_tensors = []
|
||||
for node in graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
input_tensors.append(node.meta['fwd_out'])
|
||||
if node.op == "placeholder":
|
||||
input_tensors.append(node.meta["fwd_out"])
|
||||
return input_tensors
|
||||
|
||||
@staticmethod
|
||||
def _extract_unused_output(node: Node) -> int:
|
||||
"""Extract unused output from `torch.fx.Node`"""
|
||||
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
|
||||
return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
|
||||
|
||||
@staticmethod
|
||||
def _extract_btmp(node: List[Node]) -> int:
|
||||
@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
for k, v in deps.items():
|
||||
k: Node
|
||||
if v > 0:
|
||||
deps_size += k.meta['bwd_mem_out']
|
||||
if v == float('-inf'):
|
||||
deps_size += k.meta["bwd_mem_out"]
|
||||
if v == float("-inf"):
|
||||
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
|
||||
|
||||
return deps_size
|
||||
@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
deps = {}
|
||||
for n in reversed(node):
|
||||
deps[n] = len(n.all_input_nodes)
|
||||
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
|
||||
btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
|
||||
for child in n.users:
|
||||
if child in deps:
|
||||
deps[child] -= 1
|
||||
if deps[child] <= 0:
|
||||
deps[child] = float('-inf') # free
|
||||
deps[child] = float("-inf") # free
|
||||
return btmp
|
||||
|
||||
@staticmethod
|
||||
@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
if m < mmin:
|
||||
cost_table[m][i][idx] = float("inf")
|
||||
else:
|
||||
leaf_checkpoints = [(j,
|
||||
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
|
||||
for j in range(i + 1, idx + 1)
|
||||
if m >= x[j]]
|
||||
leaf_checkpoints = [
|
||||
(j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
|
||||
for j in range(i + 1, idx + 1)
|
||||
if m >= x[j]
|
||||
]
|
||||
if leaf_checkpoints:
|
||||
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
|
||||
else:
|
||||
@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
[
|
||||
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
|
||||
f"--build-lib={this_dir}"
|
||||
f"{sys.executable}",
|
||||
f"{os.path.join(this_dir, 'build_c_ext.py')}",
|
||||
"build_ext",
|
||||
f"--build-lib={this_dir}",
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
return compute_table(chain, mmax)
|
||||
|
||||
@staticmethod
|
||||
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
|
||||
back_ptr: List[Any]) -> "Sequence":
|
||||
def _backtrack(
|
||||
chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
|
||||
) -> "Sequence":
|
||||
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
|
||||
|
||||
Args:
|
||||
@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
if back_ptr[budget][lhs][rhs][0]:
|
||||
sequence += [
|
||||
ForwardEnable(lhs),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
|
||||
back_ptr),
|
||||
CheckpointSolverRotor._backtrack(
|
||||
chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
|
||||
),
|
||||
Backward(lhs),
|
||||
]
|
||||
else:
|
||||
@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
sequence += [ForwardCheck(lhs)]
|
||||
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
|
||||
sequence += [
|
||||
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
|
||||
back_ptr),
|
||||
CheckpointSolverRotor._backtrack(
|
||||
chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
|
||||
),
|
||||
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
|
||||
]
|
||||
return sequence
|
||||
@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
"""
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
fwd_list = op_list[: op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1 :]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
n.meta["activation_checkpoint"] = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'] = [ckpt_idx]
|
||||
n.meta["activation_checkpoint"] = [ckpt_idx]
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [idx]
|
||||
@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta["activation_checkpoint"].append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta["activation_checkpoint"].append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.meta['activation_checkpoint'].append(ckpt_idx)
|
||||
n.meta["activation_checkpoint"].append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
for start_idx, end_idx in ckpt_regions:
|
||||
nested_length = max(
|
||||
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
|
||||
len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
|
||||
)
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
|
||||
len(op_list[idx].meta['activation_checkpoint']))
|
||||
op_list[idx].meta["activation_checkpoint"] += [None] * (
|
||||
nested_length - len(op_list[idx].meta["activation_checkpoint"])
|
||||
)
|
||||
|
@@ -1,20 +1,21 @@
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Any, Iterable, List
|
||||
from typing import List
|
||||
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
class Chain:
|
||||
|
||||
def __init__(self,
|
||||
ftime: List[float],
|
||||
btime: List[float],
|
||||
x: List[int],
|
||||
xbar: List[int],
|
||||
ftmp: List[int],
|
||||
btmp: List[int],
|
||||
check_consistency: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
ftime: List[float],
|
||||
btime: List[float],
|
||||
x: List[int],
|
||||
xbar: List[int],
|
||||
ftmp: List[int],
|
||||
btmp: List[int],
|
||||
check_consistency: bool = True,
|
||||
):
|
||||
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
|
||||
See paper https://hal.inria.fr/hal-02352969 for details.
|
||||
|
||||
@@ -37,9 +38,14 @@ class Chain:
|
||||
raise AttributeError("In Chain, input lists do not have consistent lengths")
|
||||
|
||||
def check_lengths(self):
|
||||
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1))
|
||||
return (
|
||||
(len(self.ftime) == len(self))
|
||||
and (len(self.btime) == len(self) + 1)
|
||||
and (len(self.x) == len(self) + 1)
|
||||
and (len(self.ftmp) == len(self))
|
||||
and (len(self.btmp) == len(self) + 1)
|
||||
and (len(self.xbar) == len(self) + 1)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
chain_list = []
|
||||
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
|
||||
|
||||
|
||||
class Forwards(Operation):
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.index = (start, end)
|
||||
|
||||
@@ -109,9 +114,9 @@ class Forwards(Operation):
|
||||
|
||||
def cost(self, chain: Chain):
|
||||
if chain is not None:
|
||||
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
|
||||
return sum(chain.ftime[self.index[0] : self.index[1] + 1])
|
||||
else:
|
||||
return (self.index[1] - self.index[0] + 1)
|
||||
return self.index[1] - self.index[0] + 1
|
||||
|
||||
|
||||
def isForward(op):
|
||||
@@ -132,7 +137,6 @@ class Backward(Operation):
|
||||
|
||||
|
||||
class Loss(Operation):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
|
||||
|
||||
|
||||
class Sequence(list):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
Reference in New Issue
Block a user