[pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for the RPC pipeline-parallel framework (#1684)
* [pipeline/tuning] improve dispatch performance in both time and space cost
* [pipeline/converge] add an interface for testing convergence
* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style
* Update PipelineBase.py
* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera
* [pipeline/chimera] test Chimera | fix an initialization bug
* [pipeline/pytree] add pytree to process args and kwargs | provide a way to process args and kwargs after forward
* [pipeline/fix-bug] num_microbatches supports any integer | stable Chimera | launch tool for the RPC pipeline-parallel framework
@@ -3,9 +3,7 @@ from enum import Enum
 from typing import List, Any, Tuple, Dict, Callable
 from functools import partial
 from abc import ABC, abstractmethod
-import sys
-import os
-import time
+import math
 import inspect
 
 import torch
@@ -831,13 +829,16 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
         batch_lengths = get_batch_lengths(batch)
+        batch_length = batch_lengths[0]
 
         if labels is not None and not forward_only:
             assert hasattr(
                 self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
 
         num_microbatches = self.num_microbatches
-        microbatch_size = batch_lengths[0] // num_microbatches
+
+        assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
+        microbatch_size = math.ceil(batch_length / num_microbatches)
         device = self.device
 
         # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
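The switch from floor to ceil division is what lets num_microbatches be any integer: with the old // rule, any remainder of batch_length modulo num_microbatches was silently dropped. A minimal standalone sketch (the numbers are illustrative, not taken from the engine) shows the difference, and why the clamp in the next hunk becomes necessary:

import math

# Hypothetical sizes for illustration only; not values from the engine.
batch_length = 10
num_microbatches = 4

floor_size = batch_length // num_microbatches             # old rule -> 2
ceil_size = math.ceil(batch_length / num_microbatches)    # new rule -> 3

print(floor_size * num_microbatches)    # 8: the old rule loses 2 samples
print(ceil_size * num_microbatches)     # 12: the new rule overshoots, so the
                                        # last microbatch is clamped below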
@@ -852,7 +853,7 @@ class PipelineEngineBase(ABC, nn.Module):
             # to prevent exceed of wait limitations
             self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
             batch_start = microbatch_size * microbatch_id
-            batch_end = batch_start + microbatch_size
+            batch_end = min(batch_start + microbatch_size, batch_length)
 
             # set input
             microbatch = split_batch(batch, batch_start, batch_end, device)
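Putting the two hunks together, the slicing can be sketched as a standalone helper; the function below is illustrative only and not part of ColossalAI's API:

import math

def microbatch_bounds(batch_length: int, num_microbatches: int):
    # Mirrors the new logic: ceil-sized microbatches, with the end index of
    # each slice clamped to batch_length so the last one carries the remainder.
    microbatch_size = math.ceil(batch_length / num_microbatches)
    bounds = []
    for microbatch_id in range(num_microbatches):
        batch_start = microbatch_size * microbatch_id
        batch_end = min(batch_start + microbatch_size, batch_length)
        bounds.append((batch_start, batch_end))
    return bounds

print(microbatch_bounds(10, 4))    # [(0, 3), (3, 6), (6, 9), (9, 10)]

With batch_length = 10 and num_microbatches = 4, the first three microbatches hold 3 samples each and the last holds the remaining 1, so no samples are lost.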