[pipeline/rank_recorder] fix bug when processing data before backward | add a tool for debugging across multiple ranks (#1681)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test Chimera | fix an initialization bug

* [pipeline/pytree] add pytree to process args and kwargs | provide pytree_filter to process args and kwargs after forward

Author: Kirigaya Kazuto
Date: 2022-10-09 17:32:57 +08:00
Committer: GitHub
Parent: 517b63939a
Commit: 3b2a59b0ba
6 changed files with 283 additions and 13 deletions


@@ -5,6 +5,7 @@ from functools import partial
from abc import ABC, abstractmethod
import sys
import os
import time
import inspect
import torch
@@ -12,12 +13,13 @@ from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                        pytree_map, get_real_args_kwargs, use_color_debug)
+                                        pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
class Phase(Enum):
@@ -469,6 +471,7 @@ class WorkerBase(ABC):
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion:
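
In this hunk the stage runs its module partition on the dispatched inputs and, on the last stage, feeds the result to the criterion to produce the loss. A minimal self-contained sketch of that pattern (the objects below are toy stand-ins, not the worker's real attributes):

# Hedged sketch; module_partition, criterion and is_last_stage are toy
# stand-ins, not ColossalAI internals.
import torch
from torch import nn

module_partition = nn.Linear(4, 2)   # this stage's slice of the model
criterion = nn.MSELoss()
is_last_stage = True

args, kwargs = (torch.randn(3, 4),), {}
consume_result = module_partition(*args, **kwargs)
if is_last_stage and criterion:
    labels = torch.randn(3, 2)                           # labels for this microbatch
    consume_result = criterion(consume_result, labels)   # last stage emits the loss
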
@@ -495,7 +498,6 @@ class WorkerBase(ABC):
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
@@ -521,19 +523,19 @@ class WorkerBase(ABC):
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
# take tensor only (for only tensor can do backward)
stage_outputs_tensors = []
pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
# overlap recompute and future.wait
grad_tensors = get_real_args_kwargs(args)
if not is_last_stage:
    grad_tensors = get_real_args_kwargs(args)
else:
    grad_tensors = None
# print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
# take tensor only (for only tensor can do backward)
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
# there is a hypothesis that a node in kwargs can't be a non-leaf node in the graph,
# so we don't need to save the grads of nodes in kwargs.
consume_result = []
if not is_first_stage:
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
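
This is the core of the fix: rather than collecting every tensor from stage_outputs via pytree_map, the new code keeps only outputs that require grad and drops None entries from the incoming gradients before calling autograd.backward. A small self-contained sketch of why the filtering matters (keep_tensors below is an illustrative stand-in, not pytree_filter itself):

# Toy illustration of the filtered backward; keep_tensors stands in for
# pytree_filter and the tensors are made up.
import torch
from torch import autograd

def keep_tensors(objs, pred):
    # keep only the tensors in a flat list for which pred(t) is True
    return [x for x in objs if isinstance(x, torch.Tensor) and pred(x)]

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)

# a stage may return a mix of differentiable and non-differentiable outputs
stage_outputs = [x * w, torch.arange(4.0)]   # second entry has no grad_fn
grad_tensors = [torch.ones(4), None]         # grads arriving from the next stage

outputs = keep_tensors(stage_outputs, lambda t: t.requires_grad)
grads = keep_tensors(grad_tensors, lambda t: True)   # drops the None placeholder

# without filtering, backward would choke on the non-differentiable output;
# after filtering only matching (output, grad) pairs remain
autograd.backward(outputs, grad_tensors=grads)
print(x.grad, w.grad)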


@@ -110,7 +110,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
-assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
+# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
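
This hunk keeps the divisibility check for interleaved scheduling but comments out the stricter assertion, which would reject configurations such as num_microbatches=8, stage_num=4, chunk=2 even though they satisfy the divisibility requirement. A hedged sketch of the surviving constraint as a standalone check (the function name is illustrative, not library API):

# Illustrative standalone version of the surviving check; not library code.
def check_interleaved_microbatches(num_microbatches: int, stage_num: int, chunk: int) -> None:
    if chunk > 1:
        # interleaving splits each pipeline stage into `chunk` model chunks,
        # so the microbatches must divide evenly across the stages
        assert num_microbatches % stage_num == 0, \
            "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

check_interleaved_microbatches(num_microbatches=8, stage_num=4, chunk=2)  # passes
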


@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
-process_types(:class: `type | tuple[type]`): types to determine the type to process
+process_types (:class: `type | tuple[type]`): types to determine the type to process
+map_all (:class: `bool`): if map_all is True, then any type of element will use fn
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
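
Going by the docstring above, pytree_map walks a nested structure and applies fn either to leaves of the given process_types or, with map_all=True, to every leaf. A usage sketch with behaviour assumed from the docstring rather than verified against the implementation:

# Behaviour assumed from the docstring: structure is preserved and fn only
# touches leaves of process_types unless map_all=True.
import torch
from colossalai.pipeline.rpc.utils import pytree_map

nested = {"hidden": [torch.ones(2), 3], "mask": (torch.zeros(2),)}

# double only the tensors; the plain int should be left untouched
doubled = pytree_map(nested, lambda t: t * 2, process_types=torch.Tensor)

# map_all=True applies fn to every leaf, e.g. to record each leaf's type
leaf_types = pytree_map(nested, lambda x: type(x), map_all=True)
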
@@ -59,6 +60,20 @@ def type_detail(obj):
    return pytree_map(obj, lambda x: type(x), map_all=True)


def pytree_filter(fn, obj, process_types):
    if obj is None:
        return None

    filters = []

    def condition_append(obj):
        if fn(obj):
            filters.append(obj)

    pytree_map(obj, fn=condition_append, process_types=process_types)
    return filters


def get_real_args_kwargs(args_or_kwargs):
    args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
    # TODO : combine producer and consumer
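
From the code shown, pytree_filter returns a flat list of the leaves of process_types that satisfy fn, and get_real_args_kwargs resolves Future leaves by mapping .wait() over them. A combined usage sketch under those assumptions (the data is made up):

# Behaviour inferred from the hunk above; the data is made up.
import torch
from torch.futures import Future
from colossalai.pipeline.rpc.utils import pytree_filter, pytree_map

outputs = [torch.randn(2, requires_grad=True), torch.randn(2), "metadata", None]

# keep only tensors that participate in autograd, as the backward path does
trainable = pytree_filter(lambda t: t.requires_grad, outputs, process_types=torch.Tensor)

# resolve futures nested inside args/kwargs before using them
fut = Future()
fut.set_result(torch.ones(2))
ready = pytree_map({"activation": fut, "step": 1}, fn=lambda x: x.wait(), process_types=Future)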