Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-21 17:40:33 +00:00)
[pipeline/rank_recorder] fix bug when processing data before backward | add a tool for multi-rank debugging (#1681)
* [pipeline/tuning] improve dispatch performance in both time and space cost
* [pipeline/converge] add an interface for testing convergence
* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style
* Update PipelineBase.py
* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera
* [pipeline/chimera] test Chimera | fix an initialization bug
* [pipeline/pytree] add pytree utilities to process args and kwargs | allow processing args and kwargs after forward
@@ -5,6 +5,7 @@ from functools import partial
 from abc import ABC, abstractmethod
 import sys
 import os
 import time
 import inspect

 import torch
@@ -12,12 +13,13 @@ from torch import nn
 import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef

 from torch import autograd
 from torch import optim

 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                            pytree_map, get_real_args_kwargs, use_color_debug)
+                                            pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)


 class Phase(Enum):
@@ -469,6 +471,7 @@ class WorkerBase(ABC):

         else:
             consume_result = self.module_partition(*args, **kwargs)

         # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )

         if is_last_stage and self.criterion:
@@ -495,7 +498,6 @@ class WorkerBase(ABC):
                                       stage_input_kwargs,
                                       stage_outputs,
                                       checkpoint=use_checkpoint)

         # if not forward_only, do the backward
         if not forward_only:
             if is_last_stage:  # if it is the last stage, trigger backward automatically
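Aside on the cache saved above: the stage's input args/kwargs and outputs are stored together with the checkpoint flag, and the backward hunk below re-runs the module partition on those saved inputs when checkpointing is enabled. A minimal, self-contained sketch of that save-then-recompute pattern (module_partition and the surrounding variable names are stand-ins, not the worker's actual attributes):

import torch
from torch import nn

module_partition = nn.Linear(4, 4)
use_checkpoint = True

# forward: run the partition and keep its inputs so backward can rebuild the graph if needed
stage_input_args = (torch.randn(2, 4, requires_grad=True),)
stage_input_kwargs = {}
stage_outputs = [module_partition(*stage_input_args, **stage_input_kwargs)]

# backward: under checkpointing, recompute the outputs from the saved inputs, mirroring
# `stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]` below
if use_checkpoint:
    stage_outputs = [module_partition(*stage_input_args, **stage_input_kwargs)]

torch.autograd.backward(stage_outputs, grad_tensors=[torch.ones(2, 4)])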
@@ -521,19 +523,19 @@ class WorkerBase(ABC):
             if use_checkpoint:
                 stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]

-            # take tensor only (for only tensor can do backward)
-            stage_outputs_tensors = []
-            pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
-
             # overlap recompute and future.wait
-            grad_tensors = get_real_args_kwargs(args)
+            if not is_last_stage:
+                grad_tensors = get_real_args_kwargs(args)
+            else:
+                grad_tensors = None

             # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
-            autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
+            # take tensor only (for only tensor can do backward)
+            stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
+            grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
+
+            autograd.backward(stage_outputs, grad_tensors=grad_tensors)

             # collect grad of input tensor
             # there is a hypothesis that a node in kwargs can't be a non-leaf node in the graph,
             # so we don't need to save the grad of nodes in kwargs.
             consume_result = []
             if not is_first_stage:
                 pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
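A note on the change just above before the diff moves to the scheduling engine: grad_tensors is now resolved from the incoming args only on non-last stages, while the last stage keeps grad_tensors=None and starts backward from its own loss; pytree_filter then drops non-differentiable outputs and None gradients before autograd.backward. A small self-contained sketch of that control flow (backward_step and received_grads are illustrative names, not from the source):

import torch
from torch import autograd

def backward_step(stage_outputs, received_grads, is_last_stage: bool):
    # mirror of the fix above: only use incoming gradients when a successor stage exists
    grad_tensors = None if is_last_stage else received_grads
    autograd.backward(stage_outputs, grad_tensors=grad_tensors)

# last stage: the output is the scalar loss, no gradients arrive from downstream
loss = torch.randn(4, requires_grad=True).sum()
backward_step([loss], received_grads=None, is_last_stage=True)

# intermediate stage: outputs are activations, gradients come from the next stage
acts = torch.randn(2, 3, requires_grad=True) * 2
backward_step([acts], received_grads=[torch.ones(2, 3)], is_last_stage=False)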
@@ -110,7 +110,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
         if chunk > 1:
             assert num_microbatches % stage_num == 0, \
                 "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
-        assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
+        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
         use_1F1B = True

         super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
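For readers of the hunk above, a minimal sketch of the configuration rule it enforces, with the stricter microbatch-count check shown where the commit comments it out (validate_1f1b_config is a hypothetical helper; the real check sits inline in OneFOneBPipelineEngine.__init__):

def validate_1f1b_config(stage_num: int, num_microbatches: int, chunk: int = 1) -> None:
    # interleaved schedule: each stage holds `chunk` partitions, so the microbatch
    # count must split evenly across the pipeline stages
    if chunk > 1:
        assert num_microbatches % stage_num == 0, \
            "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
    # the commit disables this stricter requirement:
    # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"

# example: 4 stages with 2 chunks each and 8 microbatches passes the divisibility check
validate_1f1b_config(stage_num=4, num_microbatches=8, chunk=2)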
@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
     Args:
         obj (:class:`Any`): object to process
         fn (:class:`Callable`): a function to process subobject in obj
-        process_types(:class: `type | tuple[type]`): types to determine the type to process
+        process_types (:class: `type | tuple[type]`): types to determine the type to process
         map_all (:class: `bool`): if map_all is True, then any type of element will use fn

     Returns:
         :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
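To make the documented behavior concrete, a short usage sketch for pytree_map based only on the signature and docstring above (running it requires the colossalai package; the nested structure is made up for illustration):

import torch
from colossalai.pipeline.rpc.utils import pytree_map  # import path as shown in the diff

nested = {"logits": torch.ones(2, 3), "meta": {"step": 1, "mask": torch.zeros(2)}}

# apply fn only to leaves whose type is in process_types, keeping the structure (per the docstring)
doubled = pytree_map(nested, lambda t: t * 2, process_types=torch.Tensor)

# map_all=True applies fn to every leaf; this is exactly how type_detail records types
types = pytree_map(nested, lambda x: type(x), map_all=True)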
@@ -59,6 +60,20 @@ def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)


+def pytree_filter(fn, obj, process_types):
+    if obj is None:
+        return None
+
+    filters = []
+
+    def condition_append(obj):
+        if fn(obj):
+            filters.append(obj)
+
+    pytree_map(obj, fn=condition_append, process_types=process_types)
+    return filters
+
+
 def get_real_args_kwargs(args_or_kwargs):
     args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
     # TODO : combine producer and consumer
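To close out the utils changes, a brief usage sketch for the newly added pytree_filter and for the Future-resolution idiom in get_real_args_kwargs, assuming only what the diff shows plus the standard torch.futures API (running it requires the colossalai package):

import torch
from torch.futures import Future
from colossalai.pipeline.rpc.utils import pytree_filter, pytree_map  # paths as in the diff

# pytree_filter: collect the tensor leaves that satisfy the predicate, as the backward fix does
outputs = [torch.randn(3, requires_grad=True).sum(), torch.zeros(2)]
differentiable = pytree_filter(lambda x: x.requires_grad, outputs, process_types=torch.Tensor)

# get_real_args_kwargs-style resolution: map wait() over every Future to obtain the real values
fut = Future()
fut.set_result(torch.ones(2))
resolved = pytree_map([fut, torch.zeros(2)], fn=lambda f: f.wait(), process_types=Future)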