Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 09:38:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,9 +3,9 @@ from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, Pipeli
 from .stage_manager import PipelineStageManager

 __all__ = [
-    'PipelineSchedule',
-    'OneForwardOneBackwardSchedule',
-    'InterleavedSchedule',
-    'PipelineP2PCommunication',
-    'PipelineStageManager',
+    "PipelineSchedule",
+    "OneForwardOneBackwardSchedule",
+    "InterleavedSchedule",
+    "PipelineP2PCommunication",
+    "PipelineStageManager",
 ]
@@ -29,11 +29,11 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
         Any: object after unpickled
     """
     buf = tensor.numpy().tobytes()[:tensor_size]
-    if b'cuda' in buf:
+    if b"cuda" in buf:
         buf_array = bytearray(buf)
         device_index = torch.cuda.current_device()
         # There might be more than one output tensors during forward
-        for cuda_str in re.finditer(b'cuda', buf_array):
+        for cuda_str in re.finditer(b"cuda", buf_array):
             pos = cuda_str.start()
             buf_array[pos + 5] = 48 + device_index
         buf = bytes(buf_array)
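Aside (illustrative, not part of the diff): a minimal standalone sketch of the byte-patching idea in the hunk above; the pickle fragment and device index are assumptions, not values from this commit.

import re

buf = b"...cuda:0..."  # illustrative stand-in for a pickled buffer containing a device string
device_index = 3       # assumed current device index (single digit, as in the original code)

buf_array = bytearray(buf)
for cuda_str in re.finditer(b"cuda", buf_array):
    pos = cuda_str.start()
    # "cuda" is 4 bytes and ":" is 1, so pos + 5 is the device digit;
    # 48 is ord("0"), so this writes the ASCII digit of device_index.
    buf_array[pos + 5] = 48 + device_index
buf = bytes(buf_array)
print(buf)  # b'...cuda:3...'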
@@ -45,10 +45,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     return unpickle


-def _broadcast_object_list(object_list: List[Any],
-                           src: int,
-                           group: ProcessGroup,
-                           device: Optional[Union[torch.device, str, int]] = None):
+def _broadcast_object_list(
+    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
+):
     """This is a modified version of the broadcast_object_list in torch.distribution
     The only difference is that object will be move to correct device after unpickled.
     If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
@@ -99,8 +98,8 @@ def _broadcast_object_list(object_list: List[Any],
     if my_rank == src:
         object_tensor = torch.cat(tensor_list)
     else:
-        object_tensor = torch.empty(    # type: ignore[call-overload]
-            torch.sum(object_sizes_tensor).item(),    # type: ignore[arg-type]
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
         )

@@ -114,7 +113,7 @@ def _broadcast_object_list(object_list: List[Any],

     if my_rank != src:
         for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
@@ -123,8 +122,10 @@ def _broadcast_object_list(object_list: List[Any],
             unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

             # unconsistence in device
-            if isinstance(unpickle_object,
-                          torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
+            if (
+                isinstance(unpickle_object, torch.Tensor)
+                and unpickle_object.device.index != torch.cuda.current_device()
+            ):
                 unpickle_object = unpickle_object.cuda()

             object_list[i] = unpickle_object
@@ -160,7 +161,6 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:


 class PipelineP2PCommunication:
-
     def __init__(self, stage_manager: PipelineStageManager) -> None:
         self.stage_manager = stage_manager

@@ -192,8 +192,9 @@ class PipelineP2PCommunication:
         if next_rank is None:
             next_rank = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        output_tensor_grad = _recv_object(next_rank, cur_rank,
-                                          self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
+        output_tensor_grad = _recv_object(
+            next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank)
+        )

         return output_tensor_grad

@@ -3,7 +3,7 @@ from .interleaved_pp import InterleavedSchedule
 from .one_f_one_b import OneForwardOneBackwardSchedule

 __all__ = [
-    'PipelineSchedule',
-    'OneForwardOneBackwardSchedule',
-    'InterleavedSchedule',
+    "PipelineSchedule",
+    "OneForwardOneBackwardSchedule",
+    "InterleavedSchedule",
 ]
@@ -4,24 +4,15 @@ from typing import Any, List, Optional, Tuple
 import torch
 import torch.cuda
 from torch.nn import Module
-from torch.utils._pytree import (
-    SUPPORTED_NODES,
-    LeafSpec,
-    TreeSpec,
-    _is_leaf,
-    _register_pytree_node,
-    tree_flatten,
-    tree_map,
-    tree_unflatten,
-)
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten


 # this register are for torch under version 1.13.1, maybe removed in the future
-def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]:
     return list(d.values()), list(d.keys())


-def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]":
     return OrderedDict((key, value) for key, value in zip(context, values))


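Aside (illustrative, not part of the diff): a quick round-trip check of the flatten/unflatten pair kept by the hunk above; registration with torch's pytree utilities is omitted, and the sample keys are assumptions.

from collections import OrderedDict

def _odict_flatten(d):
    # Split an OrderedDict into its values and its keys (the keys act as context).
    return list(d.values()), list(d.keys())

def _odict_unflatten(values, context):
    # Zip the keys (context) back together with the values.
    return OrderedDict((key, value) for key, value in zip(context, values))

d = OrderedDict(hidden_states=1, attentions=2)
values, context = _odict_flatten(d)   # [1, 2], ['hidden_states', 'attentions']
assert _odict_unflatten(values, context) == d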
@@ -45,7 +36,7 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:

     # Recursively flatten the children
     result: List[Any] = []
-    children_specs: List['TreeSpec'] = []
+    children_specs: List["TreeSpec"] = []
     for child in child_pytrees:
         flat, child_spec = tree_flatten_hf(child)
         result += flat
@@ -87,7 +78,7 @@ def get_batch_size(batch: Any) -> int:
     for data in data_list:
         if isinstance(data, torch.Tensor):
             return data.size(0)
-    raise RuntimeError('No tensor found in the batch')
+    raise RuntimeError("No tensor found in the batch")


 def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
@@ -104,7 +95,7 @@ def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:

     def _get_tensor_slice(x: Any):
         if isinstance(x, torch.Tensor):
-            return x[start:start + micro_batch_size]
+            return x[start : start + micro_batch_size]
         return x

     return tree_map(_get_tensor_slice, batch)
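Aside (illustrative, not part of the diff): a runnable restatement of the micro-batch slicing pattern in the hunk above, applied to a toy batch dict; the field names and sizes are assumptions.

import torch
from torch.utils._pytree import tree_map

def get_micro_batch(batch, start, micro_batch_size):
    def _get_tensor_slice(x):
        # Slice tensors along the batch dimension; leave non-tensor leaves untouched.
        if isinstance(x, torch.Tensor):
            return x[start : start + micro_batch_size]
        return x
    return tree_map(_get_tensor_slice, batch)

batch = {"input_ids": torch.arange(16).reshape(8, 2), "note": "unchanged"}
micro = get_micro_batch(batch, start=2, micro_batch_size=2)
print(micro["input_ids"].shape)  # torch.Size([2, 2])
print(micro["note"])             # unchanged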
@@ -175,7 +166,7 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any:

     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
-            if len(elem_batch[0].shape) == 0:    # set loss to None in pipeline outputs
+            if len(elem_batch[0].shape) == 0:  # set loss to None in pipeline outputs
                 merged_data.append(None)
             else:
                 merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
@@ -8,17 +8,18 @@ from colossalai.pipeline.stage_manager import PipelineStageManager


 class PipelineSchedule:
-
     def __init__(self, stage_manager: PipelineStageManager) -> None:
         self.stage_manager = stage_manager

-    def forward_backward_step(self,
-                              model: Module,
-                              data_iter: Iterable,
-                              criterion: Callable[[Any, Any], Tensor],
-                              optimizer: Optional[OptimizerWrapper] = None,
-                              return_loss: bool = False,
-                              return_outputs: bool = False) -> dict:
+    def forward_backward_step(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[[Any, Any], Tensor],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
         """Forward and backward step for pipeline training.

         Args:
@@ -16,11 +16,11 @@ from .base import PipelineSchedule


 class InterleavedSchedule(PipelineSchedule):
-
     def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
         self.num_model_chunks = num_model_chunks
-        assert num_microbatches % self.num_model_chunks == 0, \
-            "Number of microbatches should be an integer multiple of number of model chunks"
+        assert (
+            num_microbatches % self.num_model_chunks == 0
+        ), "Number of microbatches should be an integer multiple of number of model chunks"
         super().__init__(stage_manager)
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
@@ -42,8 +42,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        assert self.batch_size % self.num_microbatches == 0, \
-            "Batch size should divided by the number of microbatches"
+        assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
         self.microbatch_size = self.batch_size // self.num_microbatches

     def load_micro_batch(self, model_chunk_id: int) -> Any:
@@ -72,7 +71,7 @@ class InterleavedSchedule(PipelineSchedule):
         microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
         model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
         if not forward:
-            model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
+            model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id

     def is_first_stage(self, model_chunk_id: int) -> bool:
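Aside (illustrative, not part of the diff): the chunk-id mapping touched by the hunk above, as a standalone function with assumed values num_stages=4 and num_model_chunks=2.

def get_model_chunk_id(microbatch_id, num_stages, num_model_chunks, forward=True):
    # Microbatches are grouped by num_stages * num_model_chunks; within a group,
    # every num_stages consecutive microbatches belong to the same model chunk.
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not forward:
        # The backward pass walks the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(i, 4, 2) for i in range(8)])                  # [0, 0, 0, 0, 1, 1, 1, 1]
print([get_model_chunk_id(i, 4, 2, forward=False) for i in range(8)])   # [1, 1, 1, 1, 0, 0, 0, 0]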
@@ -161,13 +160,15 @@ class InterleavedSchedule(PipelineSchedule):
         if not self.is_first_stage(model_chunk_id):
             self.comm.send_backward(input_object, prev_rank)

-    def forward_step(self,
-                     model_chunk: Module,
-                     model_chunk_id: int,
-                     input_obj: Optional[dict],
-                     criterion: Callable,
-                     accum_loss: Optional[torch.Tensor] = None,
-                     outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+    def forward_step(
+        self,
+        model_chunk: Module,
+        model_chunk_id: int,
+        input_obj: Optional[dict],
+        criterion: Callable,
+        accum_loss: Optional[torch.Tensor] = None,
+        outputs: Optional[List[Any]] = None,
+    ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
         Args:
             model (Module): Model Chunk to be run
@@ -195,8 +196,13 @@ class InterleavedSchedule(PipelineSchedule):
         else:
             return output_obj

-    def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
-                      output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+    def backward_step(
+        self,
+        optimizer: OptimizerWrapper,
+        input_obj: Optional[dict],
+        output_obj: Union[dict, torch.Tensor],
+        output_obj_grad: Optional[dict],
+    ) -> Optional[dict]:
         """Backward one step of the pipeline

         Args:
@@ -235,13 +241,15 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad

-    def forward_backward_step(self,
-                              model_chunk: Module,
-                              data_iter: Iterable,
-                              criterion: Callable[..., Any],
-                              optimizer: Optional[OptimizerWrapper] = None,
-                              return_loss: bool = False,
-                              return_outputs: bool = False) -> dict:
+    def forward_backward_step(
+        self,
+        model_chunk: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
         """Runs interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
@@ -321,7 +329,7 @@ class InterleavedSchedule(PipelineSchedule):
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
             model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
-            last_iteration = (i == (num_microbatches_remaining - 1))
+            last_iteration = i == (num_microbatches_remaining - 1)

             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             if forward_only:
@@ -369,4 +377,4 @@ class InterleavedSchedule(PipelineSchedule):

         if outputs is not None:
             outputs = merge_batch(outputs)
-        return {'loss': accum_loss, 'outputs': outputs}
+        return {"loss": accum_loss, "outputs": outputs}
@@ -25,11 +25,12 @@ from .base import PipelineSchedule


 class OneForwardOneBackwardSchedule(PipelineSchedule):
-
-    def __init__(self,
-                 stage_manager: PipelineStageManager,
-                 num_microbatches: Optional[int] = None,
-                 microbatch_size: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        stage_manager: PipelineStageManager,
+        num_microbatches: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+    ) -> None:
         """1F1B pipeline schedule.

         Args:
@@ -38,8 +39,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
         """
         super().__init__(stage_manager)
-        assert num_microbatches is not None or microbatch_size is not None, \
-            "Either num_microbatches or microbatch_size should be provided"
+        assert (
+            num_microbatches is not None or microbatch_size is not None
+        ), "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
         self.microbatch_size = microbatch_size
@@ -62,12 +64,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
         if not self._use_microbatch_size:
-            assert self.batch_size % self.num_microbatches == 0, \
-                "Batch size should divided by the number of microbatches"
+            assert (
+                self.batch_size % self.num_microbatches == 0
+            ), "Batch size should divided by the number of microbatches"
             self.microbatch_size = self.batch_size // self.num_microbatches
         else:
-            assert self.batch_size % self.microbatch_size == 0, \
-                "Batch size should divided by the microbatch size"
+            assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
             self.num_microbatches = self.batch_size // self.microbatch_size

     def load_micro_batch(self) -> Any:
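Aside (illustrative, not part of the diff): a tiny numeric check of the two configuration paths asserted in the hunk above, with an assumed batch size of 32.

batch_size = 32

# Path 1: num_microbatches is provided, microbatch_size is derived.
num_microbatches = 8
assert batch_size % num_microbatches == 0
microbatch_size = batch_size // num_microbatches   # -> 4

# Path 2: microbatch_size is provided, num_microbatches is derived.
microbatch_size = 4
assert batch_size % microbatch_size == 0
num_microbatches = batch_size // microbatch_size   # -> 8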
@@ -136,12 +138,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank)

-    def forward_step(self,
-                     model: Module,
-                     input_obj: Optional[dict],
-                     criterion: Callable,
-                     accum_loss: Optional[torch.Tensor] = None,
-                     outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+    def forward_step(
+        self,
+        model: Module,
+        input_obj: Optional[dict],
+        criterion: Callable,
+        accum_loss: Optional[torch.Tensor] = None,
+        outputs: Optional[List[Any]] = None,
+    ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline

         Args:
@@ -159,7 +163,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
         output_obj = model_forward(model, micro_batch, input_obj)
         if self.stage_manager.is_last_stage():
-
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
@@ -169,8 +172,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         else:
             return output_obj

-    def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
-                      output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+    def backward_step(
+        self,
+        optimizer: OptimizerWrapper,
+        input_obj: Optional[dict],
+        output_obj: Union[dict, torch.Tensor],
+        output_obj_grad: Optional[dict],
+    ) -> Optional[dict]:
         """Backward one step of the pipeline

         Args:
@@ -208,13 +216,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 input_obj_grad[k] = v.grad
         return input_obj_grad

-    def forward_backward_step(self,
-                              model: Module,
-                              data_iter: Iterable,
-                              criterion: Callable[..., Any],
-                              optimizer: Optional[OptimizerWrapper] = None,
-                              return_loss: bool = False,
-                              return_outputs: bool = False) -> dict:
+    def forward_backward_step(
+        self,
+        model: Module,
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
@@ -273,7 +283,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
-            last_iteration = (i == (num_microbatches_remaining - 1))
+            last_iteration = i == (num_microbatches_remaining - 1)

             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
             if forward_only:
@@ -316,5 +326,5 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if outputs is not None:
             if isinstance(model, ModelWrapper):
                 model = model.unwrap()
-            outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
-        return {'loss': accum_loss, 'outputs': outputs}
+            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+        return {"loss": accum_loss, "outputs": outputs}
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple

 import torch.distributed as dist
@@ -28,13 +27,11 @@ class PipelineStageManager:
         # init prev and next coord
         coord = self.pg_mesh.coordinate()
         # the prev rank of rank0 is the last rank
-        prev_coord = coord[: self.pipeline_axis] + \
-            (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
-        self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
+        prev_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1 :]
+        self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode="wrap")
         # the next rank of the last rank is rank0
-        next_coord = coord[: self.pipeline_axis] + \
-            (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
-        self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
+        next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
+        self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")

         # init p2p process groups
         stages = list(range(self.num_stages))
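Aside (illustrative, not part of the diff): the wrap-around neighbor computation from the hunk above, sketched for a hypothetical 1-D pipeline mesh of 4 stages; the real code delegates the circular indexing to pg_mesh.ravel(..., mode="wrap").

num_stages = 4
pipeline_axis = 0
coord = (0,)  # assumed coordinate of the current stage along the pipeline axis

prev_coord = coord[: pipeline_axis] + (coord[pipeline_axis] - 1,) + coord[pipeline_axis + 1 :]
next_coord = coord[: pipeline_axis] + (coord[pipeline_axis] + 1,) + coord[pipeline_axis + 1 :]

# Circular indexing: the prev rank of stage 0 is the last stage,
# and the next rank of the last stage is stage 0.
prev_rank = prev_coord[pipeline_axis] % num_stages  # 3
next_rank = next_coord[pipeline_axis] % num_stages  # 1
print(prev_rank, next_rank)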