Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 18:09:06 +00:00
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
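
For orientation, a rough sketch (not part of the diff) of how a trainer might drive the reworked, stateless schedule API introduced here. The names schedule, train_dataloader, lr_scheduler and num_steps are placeholders, and the exact wiring in the real trainer/engine may differ:

# Hypothetical driver loop for the new schedule API (illustrative only).
# `schedule` is assumed to be a NoPipelineSchedule-like object; `model`,
# `optimizer`, `criterion` and `train_dataloader` are ordinary PyTorch objects.
grad_accum_size = 4
data_iter = iter(train_dataloader)

model, optimizer = schedule.initialize(model, optimizer)

for step in range(num_steps):
    optimizer.zero_grad()
    for _ in range(grad_accum_size):
        # forward + backward for one micro-step; the loss is already divided
        # by grad_accum_size inside forward_backward_step
        output, label, loss = schedule.forward_backward_step(
            data_iter, model, criterion,
            optimizer=optimizer,
            forward_only=False,
            grad_accum_size=grad_accum_size,
            return_loss=True)
    # parameter update, optionally with gradient clipping
    schedule.optimizer_step(model, optimizer, grad_clipping=1.0)
    lr_scheduler.step()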
@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
 import torch
 
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import get_current_device
 
 
 class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.
+    It mainly consists of forward_backward_step for the gradient backward pass and
+    optimizer_step for the parameter update.
+    For the convenience of enabling FP16, we aggregate all code that controls FP16
+    in the schedule classes.
     """
 
     def __init__(self):
-        self.initialized = False
         self.logger = get_global_dist_logger()
 
-    @property
-    @abstractmethod
-    def num_steps(self):
-        """The number of batches in training or evaluation.
-        """
-        pass
-
-    def initialize(self,
-                   dataloader=None,
-                   model=None,
-                   criterion=None,
-                   optimizer=None,
-                   lr_scheduler=None):
-        """Initializes the schedule and sets parameters before running.
-
-        :param dataloader: DataLoader in training or evaluation
-        :param model: The neural network model
-        :param criterion: Criterion for calculating loss
-        :param optimizer: Optimizer for updating the parameters
-        :param lr_scheduler: Learning rate scheduler in the process
-        """
-        self.dataloader = dataloader
-        assert model is not None, "Schedule requires a model"
-        self.model = model
-        assert criterion is not None, "Schedule requires a criterion"
-        self.criterion = criterion
-        assert optimizer is not None, "Schedule requires an optimizer"
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.initialized = True
-
-    def check_initialized(self):
-        """Checks whether the schedule is initialized.
-        """
-        assert self.initialized, \
-            'Schedule is not initialized. Call schedule.initialize(...) before using it.'
-
-    def load_batch(self):
-        """Loads a batch of the dataset. It returns the data and labels, which are
-        already on the same GPU as the model.
-
-        :return: (data, label)
-        :rtype: (Tensor, Tensor)
-        """
-        self.check_initialized()
-        if self.data_iter is None:
-            raise RuntimeError('Dataloader is not defined.')
-        data, label = next(self.data_iter)
-        return self._move_to_device(data), self._move_to_device(label)
+    @staticmethod
+    def _move_tensor(element):
+        if torch.is_tensor(element):
+            if not element.is_cuda:
+                return element.to(get_current_device()).detach()
+        return element
 
     def _move_to_device(self, data):
-        if isinstance(data, (tuple, list)):
-            data = tuple([
-                d.to(get_current_device()).detach() for d in data
-                if torch.is_tensor(d)
-            ])
+        if isinstance(data, (tuple, list)):
+            data = tuple([self._move_tensor(d) for d in data])
         elif torch.is_tensor(data):
             data = data.to(get_current_device()).detach()
         return data
 
-    def train(self, dataloader=None, mode=True):
-        """Sets the dataloader to be used and turns the model to
-        training or evaluation mode.
-
-        :param dataloader: Dataloader to be used
-        :param mode: If True, the model will be set to training mode. Otherwise, evaluation mode.
-        """
-        self.check_initialized()
-        if mode:
-            self.model.train()
-        else:
-            self.model.eval()
-        if dataloader is not None:
-            self.dataloader = dataloader
-            self.data_iter = iter(dataloader)
-
-    def zero_grad(self, forward_only=False):
-        """Cleans gradients with the optimizer.
-        """
-        if not forward_only:
-            self.check_initialized()
-            self.optimizer.zero_grad()
-
-    def get_lr(self):
-        """Returns the current learning rate.
-        """
-        if self.lr_scheduler is not None:
-            return self.lr_scheduler.get_lr()[0]
-        else:
-            return self.optimizer.param_groups[0]['lr']
-
-    def step(self):
-        """Updates the parameters and learning rate with the optimizer.
-        """
-        self.check_initialized()
-        self.optimizer.step()
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
+    def load_batch(self, data_iter):
+        """Loads a batch from the data iterator. It returns the data and labels, which are
+        already on the same GPU as the model.
+
+        :return: (data, label)
+        :rtype: (Tensor, Tensor)
+        """
+        if data_iter is None:
+            raise RuntimeError('Dataloader is not defined.')
+        data, label = next(data_iter)
+        return self._move_to_device(data), self._move_to_device(label)
+
+    def initialize(self, model, optimizer):
+        """Initializes the model and the optimizer before training.
+        This is often used in FP16 training.
+
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
+        """
+        return model, optimizer
 
     @abstractmethod
-    def forward_backward_step(self, forward_only=False, return_loss=True):
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """The process function over a batch of dataset for training or evaluation.
 
-        :param forward_only: If True, the process won't include backward.
-        :param return_loss: If False, the loss won't be returned.
+        :param data_iter: Data iterator of the dataset
+        :param model: Model used in training or evaluation
+        :param optimizer: Optimizer used in training or evaluation
+        :param criterion: Loss function
+        :param forward_only: If True, the process won't include backward
+        :param grad_accum_size: Steps of gradient accumulation
+        :param return_loss: If False, the loss won't be returned
         """
         pass
+
+    @abstractmethod
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        """Updates the parameters with the optimizer.
+
+        :param model: The neural network model
+        :param optimizer: Optimizer for updating the parameters
+        :param grad_clipping: The norm of gradient clipping
+        :type grad_clipping: float, optional
+        """
+        pass
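
Likewise not part of the diff: a minimal, hypothetical subclass sketch showing what a concrete schedule now has to implement against the abstract base above. ToySchedule is illustrative only and omits the AMP and ZeRO handling of the real schedules; it assumes torch and BaseSchedule are imported as in the module above.

class ToySchedule(BaseSchedule):
    # Hypothetical minimal schedule: one forward/backward per call, plain FP32.

    def forward_backward_step(self, data_iter, model, criterion,
                              optimizer=None, forward_only=False,
                              grad_accum_size: int = 1, return_loss=True):
        assert forward_only or return_loss
        data, label = self.load_batch(data_iter)
        output = model(*data)
        if not isinstance(output, (tuple, list)):
            output = (output,)
        # divide by grad_accum_size so accumulated gradients average correctly
        loss = criterion(*output, *label) / grad_accum_size if return_loss else None
        if not forward_only:
            loss.backward()
        return output, label, loss

    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        if grad_clipping > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
        optimizer.step()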
@@ -4,19 +4,24 @@
 try:
     import apex.amp as apex_amp
 except:
     print('apex is required for mixed precision training')
     pass
 
 try:
     import torch.cuda.amp as torch_amp
 except:
     print('PyTorch amp is not supported with the current PyTorch version')
     pass
 
+from typing import Iterable
+
+import torch.nn as nn
+from torch.optim import Optimizer
+
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.amp_type import AMP_TYPE
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
-from ._utils import convert_to_fp16
+from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
 from ._base_schedule import BaseSchedule
+from ._utils import convert_to_fp16, convert_to_fp32
+from ..amp import AMP_TYPE, GradScaler
 
 
 class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
     :type amp_type: AMP_TYPE
     :type amp_config: dict
     """
 
     def __init__(
         self,
         amp_type: AMP_TYPE = None,
@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
         assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
             'unrecognised value for argument fp16, it can only be None, torch or apex'
 
-        # LSG: check compatibility
-        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
-                ParallelMode.TENSOR) > 1:
-            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
-                'You can only use AMP_TYPE.PARALLEL for tensor parallel training'
         self.use_zero_level_2_3 = False
 
         if amp_type is not None:
@@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
             self.fp16 = False
             self.amp_type = None
 
-    @property
-    def num_steps(self):
-        return len(self.dataloader)
-
-    def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
-        super().initialize(dataloader,
-                           model,
-                           criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
-        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                       ZeroRedundancyOptimizer_Level_3)):
+    def initialize(self, model: nn.Module, optimizer: Optimizer):
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                  ZeroRedundancyOptimizer_Level_3)):
             self.use_zero_level_2_3 = True
-            assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
+            assert self.amp_type != AMP_TYPE.PARALLEL, \
+                'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
 
         if self.fp16:
             if self.amp_type == AMP_TYPE.TORCH:
-                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
+                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
             elif self.amp_type == AMP_TYPE.APEX:
-                self.model, self.optimizer = apex_amp.initialize(
-                    self.model, self.optimizer, **self.amp_cfg)
+                model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
 
-    def forward_backward_step(self, forward_only=False, return_loss=True):
+        return model, optimizer
+
+    def forward_backward_step(self,
+                              data_iter: Iterable,
+                              model: nn.Module,
+                              criterion: nn.modules.loss._Loss,
+                              optimizer: Optimizer = None,
+                              forward_only: bool = False,
+                              grad_accum_size: int = 1,
+                              return_loss: bool = True):
         """The process function that loads a batch of the dataset and feeds it to the model.
         The returned labels and loss will be None if :attr:`return_loss` is False.
 
+        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
+        :param model: Model for training and inference
+        :param criterion: Loss function for training
+        :param optimizer: Optimizer used for training
+        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
+        :param grad_accum_size: The number of iterations for gradient accumulation
+        :param return_loss: Loss will be returned if True
+        :type data_iter: Iterator
+        :type model: torch.nn.Module
+        :type criterion: torch.nn.modules.loss._Loss
+        :type optimizer: torch.optim.Optimizer
+        :type forward_only: bool, optional
+        :type grad_accum_size: int
+        :type return_loss: bool, optional
+        :return: (output, label, loss)
         """
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
 
-        data, label = self.load_batch()
+        data, label = self.load_batch(data_iter)
         loss = None
 
         # LSG: leave for debug, make sure dataloader is deterministic
         # if forward_only:
         #     img = data[0]
         #     rank = gpc.get_local_rank(ParallelMode.DATA)
         #     world_size = gpc.get_world_size(ParallelMode.DATA)
         #     group = gpc.get_group(ParallelMode.DATA)
         #     input_list = [img.clone() for _ in range(world_size)]
         #     output_list = [torch.empty_like(img) for _ in range(world_size)]
         #     output_list[rank] = img.clone()
         #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
         #     assert torch.equal(output_list[0], output_list[1])  # and torch.equal(output_list[1], output_list[2])
 
         # forward
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             with torch_amp.autocast():
-                output = self.model(*data)
+                output = model(*data)
                 if not isinstance(output, (tuple, list)):
                     output = (output,)
                 if return_loss:
-                    loss = self.criterion(*output, *label)
+                    loss = criterion(*output, *label)
         else:
             if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                 data = convert_to_fp16(data)
 
-            output = self.model(*data)
+            output = model(*data)
+
+            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
+                output = convert_to_fp32(output)
+
             if not isinstance(output, (tuple, list)):
                 output = (output,)
             if return_loss:
-                loss = self.criterion(*output, *label)
+                loss = criterion(*output, *label)
+
+        loss /= grad_accum_size
 
         if not forward_only:
             # backward
             if self.use_zero_level_2_3:
-                self.optimizer.backward(loss)
+                optimizer.backward(loss)
             elif self.fp16:
                 if self.amp_type == AMP_TYPE.APEX:
-                    with apex_amp.scale_loss(loss,
-                                             self.optimizer) as scaled_loss:
+                    with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
                         scaled_loss.backward()
                 elif self.amp_type == AMP_TYPE.TORCH:
                     self._torch_amp_scaler.scale(loss).backward()
                 elif self.amp_type == AMP_TYPE.PARALLEL:
-                    loss = self.optimizer.scale_loss(loss)
+                    loss = optimizer.scale_loss(loss)
                     loss.backward()
                     # scale back to display the original value in logs
-                    loss.div_(self.optimizer.grad_scaler.scale)
+                    loss.div_(optimizer.grad_scaler.scale)
             else:
                 loss.backward()
 
         if return_loss:
-            return output, label, loss
+            return output, label, loss * grad_accum_size
         else:
             return output, None, None
 
-    def step(self):
+    def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
         # step optimizer
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
-            self._torch_amp_scaler.step(self.optimizer)
+            if grad_clipping > 0.0:
+                self._torch_amp_scaler.unscale_(optimizer)
+                clip_grad_norm_fp32(model.parameters(), grad_clipping)
+            self._torch_amp_scaler.step(optimizer)
             self._torch_amp_scaler.update()
         else:
-            self.optimizer.step()
-
-        # update lr scheduler
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
+            if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
+                clip_grad_norm_fp32(model.parameters(), grad_clipping)
+            optimizer.step()
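
For background on the unscale-then-clip pattern that the new optimizer_step uses for AMP_TYPE.TORCH, here is a small self-contained sketch with stock PyTorch AMP. Note that the diff clips with colossalai's clip_grad_norm_fp32, while this sketch uses torch.nn.utils.clip_grad_norm_; the tiny model and data are made up for illustration and a CUDA device is assumed.

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(8, 16, device='cuda')
target = torch.randn(8, 4, device='cuda')

with autocast():
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()
# gradients are still scaled here; unscale before clipping so the
# threshold is applied to the true gradient norm
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()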
@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
 from colossalai.utils import get_current_device
 from ._base_schedule import BaseSchedule
 from ._utils import convert_to_fp16
-from ..amp_type import AMP_TYPE
+from ..amp import AMP_TYPE
 
 
 def squeeze(x: Union[Tensor, tuple, list]):
@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
         )
 
     # Pipeline schedule just puts data in memory
-    def load_batch(self):
-        self.check_initialized()
-        if self.data_iter is None:
+    def load_batch(self, data_iter):
+        if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
         self.batch_pos = 0
-        data, label = next(self.data_iter)
+        data, label = next(data_iter)
         self.batch_data, self.batch_label = \
             self._move_to_device(data), self._move_to_device(label)
         batch_size = self.batch_data.shape[0]
@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
         self.batch_pos += self.microbatch_size
         return (data,), (label,)
 
-    @property
-    def num_steps(self):
-        return len(self.dataloader)
-
-    def initialize(self,
-                   dataloader,
-                   model,
-                   criterion,
-                   optimizer,
-                   lr_scheduler=None):
-        super().initialize(dataloader,
-                           model,
-                           criterion,
-                           optimizer,
-                           lr_scheduler=lr_scheduler)
-        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                       ZeroRedundancyOptimizer_Level_3)):
+    def initialize(self, model, optimizer):
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )
@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
             'default tensor dtype is set to torch.half for fp16 training',
             ranks=[0])
 
-    def forward_step(self, input_tensor, return_tensors, return_loss=True):
+    def forward_step(self, model, criterion, input_tensor, return_tensors,
+                     grad_accum_size, return_loss=True):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
         if self.amp_type == AMP_TYPE.PARALLEL:
             input_tensor = convert_to_fp16(input_tensor)
 
         input_tensor = squeeze(input_tensor)
-        output_tensor = self.model(input_tensor)
+        output_tensor = model(input_tensor)
         output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_loss:
                 input_tensor, label = self.load_micro_batch()
-                loss_reduced = self.criterion(output_tensor, *
-                                              label) / self.num_microbatches
+                loss_reduced = criterion(output_tensor, *label) \
+                    / (self.num_microbatches * grad_accum_size)
                 return_tensors.append(
                     tuple((output_tensor, label[0], loss_reduced)))
                 return loss_reduced
@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
         else:
             return output_tensor
 
-    def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+    def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
 
         # Backward pass.
         if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
-            output_tensor = self.optimizer.scale_loss(output_tensor)
+            output_tensor = optimizer.scale_loss(output_tensor)
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
 
         # Collect the grad of the input_tensor.
@@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule):
 
         return input_tensor_grad
 
-    def forward_backward_step(self, forward_only=True, return_loss=True):
+    def forward_backward_step(self,
+                              data_iter,
+                              model,
+                              criterion,
+                              optimizer=None,
+                              forward_only=False,
+                              grad_accum_size: int = 1,
+                              return_loss=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
 
+        :return: (output, label, loss)
         """
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
 
-        self.load_batch()
+        self.load_batch(data_iter)
         num_warmup_microbatches = \
             (gpc.get_world_size(ParallelMode.PIPELINE) -
              gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = recv_tensor_meta(ft_shape)
             input_tensor = recv_forward(ft_shape)
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
                 fs_checker = send_tensor_meta(output_tensor, fs_checker)
@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))
 
-            output_tensor = self.forward_step(input_tensor,
-                                              return_tensors,
-                                              return_loss=return_loss)
+            output_tensor = self.forward_step(
+                model, criterion,
+                input_tensor, return_tensors,
+                grad_accum_size, return_loss=return_loss
+            )
             if forward_only:
                 send_forward(output_tensor)
 
@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 if last_iteration:
                     input_tensor = None
@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
 
                 output_tensor_grad = recv_backward(bt_shape)
 
-                input_tensor_grad = self.backward_step(input_tensor,
-                                                       output_tensor,
-                                                       output_tensor_grad)
+                input_tensor_grad = self.backward_step(
+                    optimizer,
+                    input_tensor, output_tensor,
+                    output_tensor_grad
+                )
 
                 send_backward(input_tensor_grad)
 
@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
                 output, label, loss = tuple(map(list, zip(*return_tensors)))
                 return (torch.cat(output, dim=0),
                         torch.cat(label, dim=0),
-                        sum(loss))
+                        sum(loss) * grad_accum_size)
             else:
                 return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:
             return tuple((None, None, None))
+
+    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
+        optimizer.step()
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
+
+
+def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
+    if isinstance(data, Tensor):
+        ret = data.float()
+    elif isinstance(data, (list, tuple)):
+        ret = [val.float() for val in data]
+    else:
+        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
+    return ret
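
As a usage note (not part of the diff), the two helpers are intended as a round trip around a half-precision forward pass. A tiny hypothetical example, assuming convert_to_fp16 and convert_to_fp32 from the module above are in scope:

import torch

# assumes convert_to_fp16 / convert_to_fp32 from the module above are imported
inputs = (torch.randn(4, 8), torch.randn(4, 8))

fp16_inputs = convert_to_fp16(inputs)        # list of torch.float16 tensors
# ... run a half-precision model on fp16_inputs here ...
fp32_outputs = convert_to_fp32(fp16_inputs)  # back to torch.float32, e.g. before the loss
print([t.dtype for t in fp32_outputs])       # [torch.float32, torch.float32]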