[pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for RPC PP framework (#1684)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test Chimera | fix initialization bug

* [pipeline/pytree] add pytree to process args and kwargs | provide … to process args and kwargs after forward (a minimal pytree sketch follows this list)

* [pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for RPC PP framework
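
A minimal sketch of the pytree handling mentioned in the list above, assuming PyTorch's `torch.utils._pytree` helpers (a private but long-standing module; the commit may use its own utility): flatten nested args/kwargs into leaves plus a tree spec, process the leaves uniformly, then rebuild the original structure after the forward pass.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Nested positional/keyword arguments treated as a single pytree.
call_args = ((torch.randn(2, 3),), {"mask": torch.ones(2, 3), "scale": 0.5})

# Flatten into a flat list of leaves plus a spec describing the nesting.
leaves, spec = tree_flatten(call_args)

# Leaves can now be processed uniformly (e.g. moved to a worker's device).
leaves = [x.to("cpu") if torch.is_tensor(x) else x for x in leaves]

# Rebuild the exact original (args, kwargs) structure after the forward.
args, kwargs = tree_unflatten(leaves, spec)
```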
Author: Kirigaya Kazuto
Date: 2022-10-10 16:01:02 +08:00
Committed by: GitHub
Parent: e5ab6be72e
Commit: 0df5034a36
4 changed files with 98 additions and 24 deletions


@@ -3,9 +3,7 @@ from enum import Enum
 from typing import List, Any, Tuple, Dict, Callable
 from functools import partial
 from abc import ABC, abstractmethod
-import sys
-import os
-import time
+import math
 import inspect
 import torch
@@ -831,13 +829,16 @@ class PipelineEngineBase(ABC, nn.Module):
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
         batch_lengths = get_batch_lengths(batch)
+        batch_length = batch_lengths[0]
         if labels is not None and not forward_only:
             assert hasattr(
                 self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
         num_microbatches = self.num_microbatches
-        microbatch_size = batch_lengths[0] // num_microbatches
+        assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
+        microbatch_size = math.ceil(batch_length / num_microbatches)
         device = self.device
         # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
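
The two added lines above switch microbatch sizing from floor to ceiling division, which is what lets `num_microbatches` be any integer up to the batch size. A quick standalone illustration of the arithmetic (not the engine code itself):

```python
import math

batch_length = 10
num_microbatches = 3

# Floor division would size microbatches at 10 // 3 = 3, covering only
# 9 samples; ceiling division sizes them at 4 so the whole batch is
# covered, with the last microbatch smaller when the division is not exact.
microbatch_size = math.ceil(batch_length / num_microbatches)
assert microbatch_size == 4  # microbatches of 4, 4, and 2 samples
```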
@@ -852,7 +853,7 @@ class PipelineEngineBase(ABC, nn.Module):
             # to prevent exceed of wait limitations
             self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
             batch_start = microbatch_size * microbatch_id
-            batch_end = batch_start + microbatch_size
+            batch_end = min(batch_start + microbatch_size, batch_length)
             # set input
             microbatch = split_batch(batch, batch_start, batch_end, device)
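
Putting the ceiling-based size and the clamped `batch_end` together, the loop slices the batch as sketched below; `split_batch` here is a hypothetical single-tensor stand-in for the repo helper of the same name, which also handles nested batches and device placement.

```python
import math
import torch

def split_batch(batch: torch.Tensor, start: int, end: int, device: str):
    # Hypothetical stand-in for the repo's split_batch: slice along the
    # batch dimension and move the microbatch to the target device.
    return batch[start:end].to(device)

batch = torch.arange(10).float()                               # batch_length = 10
batch_length = batch.shape[0]
num_microbatches = 3
microbatch_size = math.ceil(batch_length / num_microbatches)  # 4

for microbatch_id in range(num_microbatches):
    batch_start = microbatch_size * microbatch_id
    batch_end = min(batch_start + microbatch_size, batch_length)  # clamp the tail
    microbatch = split_batch(batch, batch_start, batch_end, "cpu")
    print(microbatch_id, microbatch.tolist())
# 0 [0.0, 1.0, 2.0, 3.0]
# 1 [4.0, 5.0, 6.0, 7.0]
# 2 [8.0, 9.0]
```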