[autoparallel] Add pofo sequence annotation (#1637)

* [autoparallel] annotate pofo sequence

* [autoparallel] remove unused print

* [autoparallel] fix some code
This commit is contained in:
Boyuan Yao 2022-09-24 01:52:57 +08:00 committed by GitHub
parent 04bbabeea8
commit f921733621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 133 additions and 6 deletions

View File

@ -145,7 +145,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]): def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions """This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the In pofo algorithm, during annotation, we will annotate the offload region with the
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the the input, offload_bar is a bool type indicates whether we need to offload all the
intermediate x_bars of this region. intermediate x_bars of this region.
@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None current_region = None
for idx, node in enumerate(nodes): for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple): if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
act_offload_label = node.activation_offload act_offload_label = node.activation_offload
if current_region == None: if current_region == None:

View File

@ -97,6 +97,7 @@ class PofoSolver:
self.bandwidth = bandwidth self.bandwidth = bandwidth
self.disc_chain = copy.deepcopy(self.chain) self.disc_chain = copy.deepcopy(self.chain)
self.disc_chain._discretize(self.mem_unit)
self.rotor_table = _compute_table(self.disc_chain, mem_slots) self.rotor_table = _compute_table(self.disc_chain, mem_slots)
self._compute_pofo_table() self._compute_pofo_table()
@ -142,7 +143,7 @@ class PofoSolver:
return (max(compute, comm) + compute + comm) / 2 return (max(compute, comm) + compute + comm) / 2
def _rotor_estimated_bwd_sequence(self, i, j, m, delta): def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
return _rec(self.disc_chain, i, j, math.floor(m - self.chain.cweight[i] / self.mem_unit), self.rotor_table) return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
def _common_values_enable(self, state: Tuple): def _common_values_enable(self, state: Tuple):
@ -354,6 +355,129 @@ class PofoSolver:
return result return result
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for op in fwd_list:
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = [op.index]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(op.index)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
# annotate the offload
offload_idx = 0
for idx, op in enumerate(fwd_list):
if isinstance(op, Offload):
# corner case: offload input
if op.index == 0:
if isinstance(fwd_list[idx + 1], ForwardCheck):
for n in node_list[op.index]:
setattr(n, "activation_offload", True)
else:
for n in node_list[op.index]:
setattr(n, "activation_offload", (offload_idx, True, False))
offload_idx += 1
else:
if op.has_bar:
# annotate previous node
if hasattr(node_list[op.index - 1][0], "activation_offload"):
for n in node_list[op.index - 1]:
n.activation_offload[-1] = True
else:
for n in node_list[op.index - 1]:
setattr(n, "activation_offload", [offload_idx, False, True])
offload_idx += 1
# annotate this node
if isinstance(fwd_list[idx + 1], ForwardCheck):
for n in node_list[op.index]:
setattr(n, "activation_offload", True)
else:
for n in node_list[op.index]:
setattr(n, "activation_offload", [offload_idx, True, False])
offload_idx += 1
def solver_pofo(gm: ColoGraphModule, def solver_pofo(gm: ColoGraphModule,
data, data,
bandwidth, bandwidth,
@ -398,7 +522,8 @@ def solver_pofo(gm: ColoGraphModule,
first_state = (0, 0, 0, 0, False) first_state = (0, 0, 0, 0, False)
sequence = solver.pofo_rec(first_state) sequence = solver.pofo_rec(first_state)
if sequence == None: if sequence == None:
print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!") raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
_annotate_from_pofo_sequence(sequence, node_list)
setattr(gm, "__sequence__", sequence) setattr(gm, "__sequence__", sequence)
return gm return gm

View File

@ -54,7 +54,8 @@ class Offload(Operation):
super().__init__() super().__init__()
self.index = index self.index = index
self.name = "Off" self.name = "Off"
if has_bar: self.has_bar = has_bar
if self.has_bar:
self.name += "wBar" self.name += "wBar"
def __repr__(self): def __repr__(self):
@ -67,7 +68,8 @@ class Prefetch(Operation):
super().__init__() super().__init__()
self.index = index self.index = index
self.name = "Pre" self.name = "Pre"
if has_bar: self.has_bar = has_bar
if self.has_bar:
self.name += "wBar" self.name += "wBar"
def __repr__(self): def __repr__(self):