Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[pipeline] refactor 1f1b schedule (#4115)
* [api] update optimizer wrapper to fit pipeline
* [pipeline] add base schedule
* [pipeline] add 1f1b schedule
* [test] add pipeline schedule utils test
* [pipeline] fix import
colossalai/pipeline/schedule/_utils.py (new file, 129 lines)
@@ -0,0 +1,129 @@
from typing import Any, List, Optional

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Move object to device if it is a tensor.

    Args:
        x (Any): Object to be moved.
        device (Optional[torch.device], optional): Target device. Defaults to None.

    Returns:
        Any: Moved object.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return x
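
Since `to_device` acts on a single object, it is typically applied to a whole pytree via `tree_map`. A minimal usage sketch, not part of the commit; the `batch` contents are made up for illustration and `to_device` is assumed to be in scope:

from functools import partial

import torch
from torch.utils._pytree import tree_map

# A nested batch: tensors are moved, non-tensor leaves pass through unchanged.
batch = {'input_ids': torch.zeros(8, 16, dtype=torch.long), 'meta': 'keep-as-is'}
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch = tree_map(partial(to_device, device=device), batch)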


def get_batch_size(batch: Any) -> int:
    """Get the batch size (size of dimension-0) of the first tensor in the batch.

    Args:
        batch (Any): Batch to be inspected.

    Raises:
        RuntimeError: If no tensor is found in the batch.

    Returns:
        int: Batch size.
    """
    data_list, _ = tree_flatten(batch)
    for data in data_list:
        if isinstance(data, torch.Tensor):
            return data.size(0)
    raise RuntimeError('No tensor found in the batch')


def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Get a micro batch of the original batch.

    Args:
        batch (Any): Batch to be sliced.
        start (int): Start index of the micro batch.
        micro_batch_size (int): Size of the micro batch.

    Returns:
        Any: Target micro batch.
    """

    def _get_tensor_slice(x: Any):
        if isinstance(x, torch.Tensor):
            return x[start:start + micro_batch_size]
        return x

    return tree_map(_get_tensor_slice, batch)
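
For illustration, a minimal sketch of how a schedule might slice a batch into micro batches with `get_batch_size` and `get_micro_batch`. The batch layout and micro batch size are assumptions, not part of the commit:

import torch

batch = {'x': torch.randn(8, 16), 'y': torch.randint(0, 2, (8,))}
micro_batch_size = 2
num_micro_batches = get_batch_size(batch) // micro_batch_size    # 8 // 2 == 4

micro_batches = [
    get_micro_batch(batch, i * micro_batch_size, micro_batch_size)
    for i in range(num_micro_batches)
]
assert micro_batches[0]['x'].shape == (2, 16)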


def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Call the model's forward function with data and internal inputs.

    Args:
        model (Module): Model to be called.
        data (Any): Data loaded from the data iterator.
        internal_inputs (Optional[dict]): Data from the previous stage. It must be a dict, or None if this is the first stage.

    Returns:
        Any: Outputs of the model.
    """
    if internal_inputs is None:
        internal_inputs = {}
    if isinstance(data, (list, tuple)):
        return model(*data, **internal_inputs)
    elif isinstance(data, dict):
        return model(**data, **internal_inputs)
    return model(data, **internal_inputs)
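
A minimal sketch of the three dispatch cases; the toy model and inputs are hypothetical, not part of the commit:

import torch
from torch import nn

class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

model = Toy()
x = torch.ones(4, 8)

out1 = model_forward(model, x, None)                   # single-object data, no internal inputs
out2 = model_forward(model, [x], {'scale': 2.0})       # list/tuple data unpacked as positional args
out3 = model_forward(model, {'x': x}, {'scale': 2.0})  # dict data unpacked as keyword args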


def retain_grad(x: Any) -> None:
    """Call retain_grad() on an object if it is a tensor.

    Args:
        x (Any): Object on which retain_grad() is called if it is a tensor.
    """
    if isinstance(x, torch.Tensor):
        x.retain_grad()


def detach(x: Any) -> Any:
    """Call detach() on an object if it is a tensor.

    Args:
        x (Any): Object to be detached.

    Returns:
        Any: The detached object.
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    return x
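
In a pipeline schedule, these two helpers are usually applied per-leaf with `tree_map` at stage boundaries. A hedged sketch of one common pattern, not taken from the commit; the tensor names and shapes are made up:

import torch
from torch.utils._pytree import tree_map

# Activations received from the previous stage.
recv = {'hidden_states': torch.randn(2, 16)}

# Cut the cross-stage autograd graph, then track gradients locally so the
# .grad fields can later be sent back to the previous stage.
inputs = tree_map(detach, recv)
inputs['hidden_states'].requires_grad_()
tree_map(retain_grad, inputs)    # no-op on leaves; matters for non-leaf tensors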


def merge_batch(data: List[Any]) -> Any:
    """Merge micro batches into a batch.

    Args:
        data (List[Any]): A list of micro batches.

    Returns:
        Any: Merged batch.
    """
    if len(data) == 0:
        return
    flattened_data = []
    tree_spec = None
    for d in data:
        elems, tree_spec = tree_flatten(d)
        flattened_data.append(elems)
    merged_data = []
    for elem_batch in zip(*flattened_data):
        if isinstance(elem_batch[0], torch.Tensor):
            merged_data.append(torch.cat(elem_batch, dim=0))
        else:
            merged_data.append(list(elem_batch))
    return tree_unflatten(merged_data, tree_spec)
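
A minimal round-trip sketch combining `get_micro_batch` and `merge_batch`; the data shape is an assumption for illustration:

import torch

batch = {'x': torch.arange(8, dtype=torch.float32).reshape(8, 1)}
micro_batches = [get_micro_batch(batch, start, 2) for start in range(0, 8, 2)]
merged = merge_batch(micro_batches)
assert torch.equal(merged['x'], batch['x'])    # slicing then merging restores the batch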