diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1dfd261d5..623160003 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,17 +94,23 @@ class ProcessGroupMesh: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: """Convert a coordinate to a rank. + mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. + with wrap, index out of range would be wrapped around. + For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2) Args: coords (Tuple[int, ...]): Coordinate to be converted. shape (Tuple[int, ...]): Shape of the process group mesh. + mode (Optional[str]): The mode for numpy.ravel_multi_index. Returns: int: Rank of the coordinate. """ - return np.ravel_multi_index(coord, shape) + + assert mode in ["raise", "wrap", "clip"] + return np.ravel_multi_index(coord, shape, mode) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. It the process group doesn't exist, it will be created. diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index af7a00b5c..aed85cf91 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -173,14 +173,10 @@ class PipelineP2PCommunication: Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) return input_tensor @@ -193,14 +189,11 @@ class PipelineP2PCommunication: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object(next_rank, cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) return output_tensor_grad @@ -211,12 +204,10 @@ class PipelineP2PCommunication: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. 
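The mode='wrap' argument added to ProcessGroupMesh.ravel above is what later lets the stage manager look up ring-style neighbour ranks without special-casing the pipeline ends. A minimal numpy-only sketch of the behaviour (the (1, 4, 1) mesh shape is just an example):

import numpy as np

shape = (1, 4, 1)    # e.g. a (dp, pp, tp) mesh with 4 pipeline stages

# default mode='raise' still rejects out-of-range coordinates
print(np.ravel_multi_index((0, 3, 0), shape))           # 3

# with mode='wrap' an out-of-range pipeline index is taken modulo 4,
# so the "previous" rank of stage 0 is the last stage and vice versa
print(np.ravel_multi_index((0, -1, 0), shape, 'wrap'))  # 3
print(np.ravel_multi_index((0, 4, 0), shape, 'wrap'))   # 0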
""" - if not self.stage_manager.is_last_stage(): - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -225,9 +216,7 @@ class PipelineP2PCommunication: input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.stage_manager.is_first_stage(): - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(input_object, cur_rank, prev_rank, - self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py new file mode 100644 index 000000000..35a33491b --- /dev/null +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -0,0 +1,370 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from .base import PipelineSchedule + + +class InterleavedSchedule(PipelineSchedule): + + def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: + self.num_model_chunks = num_model_chunks + assert num_microbatches % self.num_model_chunks == 0, \ + "Number of microbatches should be an integer multiple of number of model chunks" + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.microbatch_size: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. 
+ """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not forward: + model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def is_first_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the first stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the first stage. + """ + if self.stage_manager.is_first_stage() and model_chunk_id == 0: + return True + return False + + def is_last_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the last stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the last stage. + """ + if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + return True + return False + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.is_first_stage(model_chunk_id): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.is_last_stage(model_chunk_id): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For interleaved 1F1B. 
+ + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.is_last_stage(model_chunk_id): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.is_first_stage(model_chunk_id): + self.comm.send_backward(input_object, prev_rank) + + def forward_step(self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model (Module): Model Chunk to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + + if self.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) + + # Collect the grad of the input_obj. 
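A pure-torch illustration of the dict-based backward pattern used in backward_step; it assumes OptimizerWrapper.backward_by_grad(tensor, grad) behaves like tensor.backward(grad) (the real wrapper may do more), and the grad collection continues right below:

import torch

x = torch.randn(2, 4, requires_grad=True)                 # stands in for a received activation
input_obj = {'hidden_states': x}
output_obj = {'hidden_states': x * 2}
output_obj_grad = {'hidden_states': torch.ones_like(output_obj['hidden_states'])}

# mirrors optimizer.backward_by_grad(output_obj[k], grad) for each key
for k, grad in output_obj_grad.items():
    output_obj[k].backward(grad)

# in the real code retain_grad() is applied first; x is a leaf here, so .grad is already kept
input_obj_grad = {k: v.grad for k, v in input_obj.items()
                  if isinstance(v, torch.Tensor) and v.grad is not None}
print(input_obj_grad['hidden_states'])                     # a 2x4 tensor of 2.0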
+ input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step(self, + model_chunk: Module, + optimizer: OptimizerWrapper, + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False) -> dict: + """Runs interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model_chunk (List[Module]): Model Chunk to be trained. + optimizer (OptimizerWrapper): Optimizer to be used. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + + self.load_batch(data_iter) + num_model_chunks = len(model_chunk) + + # num_warmup_microbatches is the step when not all the processes are working + num_microbatches = self.num_microbatches * num_model_chunks + if forward_only: + num_warmup_microbatches = num_microbatches + else: + num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [[] for _ in range(num_model_chunks)] + output_objs = [[] for _ in range(num_model_chunks)] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # for ranks except the first one, get into recv state + # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) + input_obj = self.recv_forward(0) + input_objs[0].append(input_obj) + # Run warmup forward passes. 
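Plugging example numbers into the warm-up formula above (4 stages, 2 model chunks, 8 microbatches per chunk, i.e. 16 micro-steps per rank):

num_stages, num_model_chunks, num_microbatches = 4, 2, 8
total = num_microbatches * num_model_chunks              # 16

for stage in range(num_stages):
    warmup = (num_stages - stage - 1) * 2 + (num_model_chunks - 1) * num_stages
    warmup = min(warmup, total)
    print(stage, warmup, total - warmup)                 # stage, warm-up steps, steady 1F1B steps
# prints 0 10 6 / 1 8 8 / 2 6 10 / 3 4 12: the last stage enters steady 1F1B first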
+ for i in range(num_warmup_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=True) + + # recv first on first rank to avoid sending or recving at the same time + if self.stage_manager.is_first_stage(): + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + self.send_forward(model_chunk_id, output_obj) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + else: + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not forward_only: + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) + if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: + break + else: + model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) + + input_obj = self.recv_forward(model_chunk_id) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.recv_forward(model_chunk_id) + + else: + self.send_forward(model_chunk_id, output_obj) + # Add input_obj and output_obj to end of list. + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + + model_chunk_id = self.get_model_chunk_id(i, forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + # backward + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + + # Run cooldown backward passes. 
+ if not forward_only: + for i in range(num_microbatches_remaining, num_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=False) + # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ade3cf456..f5e4929aa 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -53,6 +53,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.microbatch_offset += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For 1F1B. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For 1F1B. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank) + def forward_step(self, model: Module, input_obj: Optional[dict], @@ -171,11 +227,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # Run warmup forward passes. for i in range(num_warmup_microbatches): - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not forward_only: input_objs.append(input_obj) @@ -185,7 +241,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. 
if num_microbatches_remaining > 0: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() # Run 1F1B in steady state. for i in range(num_microbatches_remaining): @@ -193,15 +249,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not last_iteration: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() else: # TODO adjust here - self.comm.send_forward(output_obj) - output_obj_grad = self.comm.recv_backward() + self.send_forward(output_obj) + output_obj_grad = self.recv_backward() # Add input_obj and output_obj to end of list. input_objs.append(input_obj) @@ -216,8 +272,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if last_iteration: input_obj = None else: - input_obj = self.comm.recv_forward() - self.comm.send_backward(input_obj_grad) + input_obj = self.recv_forward() + self.send_backward(input_obj_grad) # Run cooldown backward passes. if not forward_only: @@ -225,9 +281,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = self.comm.recv_backward() + output_obj_grad = self.recv_backward() input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.comm.send_backward(input_obj_grad) + self.send_backward(input_obj_grad) if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index fe228e227..6ba7dc629 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -17,28 +17,24 @@ class PipelineStageManager: Attributes: num_stages (int): Number of stages in the pipeline. stage (int): The current stage. - num_virtual_stages (int): Number of virtual stages in the pipeline. - virtual_stage (int): The current virtual stage. 
""" - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis - self.num_virtual_stages: Optional[int] = None - self.virtual_stage: Optional[int] = None self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} # init prev and next coord coord = self.pg_mesh.coordinate() - if self.stage > 0: - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) - if self.stage < self.num_stages - 1: - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + # the prev rank of rank0 is the last rank + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + # the next rank of the last rank is rank0 + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') # init p2p process groups stages = list(range(self.num_stages)) @@ -48,32 +44,28 @@ class PipelineStageManager: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self, virtual: bool = False) -> bool: - """Is the current stage the first stage. + if is_virtual: + # add the process group of the first rank and the last rank + # only used in interleaved pipeline for now + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) + if self.stage in [stages[0], stages[-1]]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + def is_first_stage(self) -> bool: + """Is the current stage the first stage. Returns: bool: Whether the current stage is the first stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == 0 return self.stage == 0 - def is_last_stage(self, virtual: bool = False) -> bool: + def is_last_stage(self) -> bool: """Is the current stage the last stage. - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. - Returns: bool: Whether the current stage is the last stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == self.num_virtual_stages - 1 return self.stage == self.num_stages - 1 @property @@ -108,7 +100,6 @@ class PipelineStageManager: Returns: int: Rank of the previous stage. """ - assert not self.is_first_stage(), "Cannot get previous rank in the first stage." return self.prev_rank def get_next_rank(self) -> int: @@ -117,39 +108,8 @@ class PipelineStageManager: Returns: int: Rank of the next stage. """ - assert not self.is_last_stage(), "Cannot get next rank in the last stage." return self.next_rank - def set_num_virtual_stages(self, num_virtual_stages: int) -> None: - """Set the number of virtual stages. - - Args: - num_virtual_stages (int): Number of virtual stages. 
- """ - self.num_virtual_stages = num_virtual_stages - - def set_virtual_stage(self, virtual_stage: int) -> None: - """Set the virtual stage. - - Args: - virtual_stage (int): Virtual stage. - """ - self.virtual_stage = virtual_stage - - @contextmanager - def switch_virtual_stage(self, virtual_stage: int) -> None: - """A context manager to switch virtual stage. - - Args: - virtual_stage (int): Target virtual stage. - """ - old_stage = self.virtual_stage - try: - self.set_virtual_stage(virtual_stage) - yield - finally: - self.set_virtual_stage(old_stage) - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py new file mode 100644 index 000000000..2ac31c8ca --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -0,0 +1,161 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 8) + self.linear3 = nn.Linear(8, 8) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + x = self.linear7(x) + x = self.linear8(x) + return x + + +def pp_linear_fwd(forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None): + + if stage_mgr.is_first_stage() and model_chunk_id == 0: + return {'input_obj': forward(data)} + elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + return forward(input_obj) + else: + return {'input_obj': forward(input_obj)} + + +@parameterize("num_micro_batches", [4, 8, 12]) +def examine_pp(num_micro_batches): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. 
+ """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = num_micro_batches + BATCH_SIZE = num_micro_batches + NUM_CHUNKS = 2 + + # create model + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) + schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + + sharded_model = torch.nn.ModuleList() + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, + stage_mgr=stage_manager, + num_chunks=NUM_CHUNKS, + model_chunk_id=len(sharded_model)), sub_model._forward) + sharded_model.append(sub_model.cuda()) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if local_rank == 0: + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, + pp_optimizer, + iter(input_list), + criterion, + return_loss=True, + return_outputs=True) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret['loss']) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + else: + assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + else: + assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_pp() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58..6e0cd1998 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -49,15 +49,6 @@ def check_stage_manager(): next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] assert stage_manager.get_next_rank() == next_rank - # check virtual stage - stage_manager.set_num_virtual_stages(PP_SIZE * 2) - assert stage_manager.num_virtual_stages == PP_SIZE * 2 - stage_manager.set_virtual_stage(stage_manager.stage * 2) - assert stage_manager.virtual_stage == stage_manager.stage * 2 - with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): - assert 
stage_manager.virtual_stage == stage_manager.stage * 2 + 1
-    assert stage_manager.virtual_stage == stage_manager.stage * 2
-
     # check p2p groups
     for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
         if rank in [prev, cur]: