Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 15:01:43 +00:00)
[pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when strict=False, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)

* Use p2p
* Cannot bidirectional send p2p
* Refactor tensor creation and serialization in P2P communication
* Fix llama forward args in flash attention
* Add flop estimate from megatron
* Support loading weight not in weight_map when strict=False in hybrid_parallel
* Use send_forward_recv_backward, etc. in 1f1b
* Use dataclass for metadata; remove torch.cuda.synchronize() as suggested
* Add comment about the torch.cuda.synchronize for potential error
* Typo
* Update hybrid_parallel_checkpoint_io.py
* Update p2p.py
* Update one_f_one_b.py
* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
This commit is contained in:
parent 28052a71fb
commit b2ad0d9e8f
@@ -1,4 +1,5 @@
 import copy
+from functools import reduce
 import logging
 import os
 from pathlib import Path
@@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Keep a record of loaded files so that file will not be repeatedly loaded.
         loaded_file = set()
 
+        missing_keys = []
+        missing_file_keys = []
+
         def _load(name: str):
             if name not in weight_map:
-                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+                missing_file_keys.append(name)
+                return
             filename = weight_map[name]
 
             # If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
-            missing_keys = []
 
             load_state_dict_into_model(
                 model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+        if len(missing_keys) == 0:
+            raise RuntimeError(
+                "No weight is loaded into the model. Please check the checkpoint files and the model structure."
+            )
+
+        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+        remain_keys = remain_keys.union(set(missing_file_keys))
+        if len(remain_keys) > 0:
+            if strict:
+                error_msgs = "Missing key(s) in state_dict: {}. ".format(
+                    ", ".join('"{}"'.format(k) for k in missing_keys)
+                )
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
+            else:
+                if self.coordinator.is_master():
+                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
+
     def save_sharded_optimizer(
         self,
         optimizer: OptimizerWrapper,
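Note on the missing-key bookkeeping above: each load_state_dict_into_model call only sees one shard, so a parameter that lives in another shard still shows up in that call's missing list. Intersecting the per-shard lists keeps only the keys no shard provided, and the union with missing_file_keys adds the keys that never appeared in weight_map. A minimal, self-contained sketch (the key names are invented for illustration):

from functools import reduce

# Each inner list plays the role of the missing_keys list filled by one
# load_state_dict_into_model() call above.
per_shard_missing = [
    ["layer1.weight", "layer2.weight"],  # shard A provided everything except these
    ["layer1.weight", "layer3.weight"],  # shard B provided everything except these
]
missing_file_keys = ["layer4.weight"]    # never listed in weight_map at all

remain_keys = reduce(lambda a, b: a & b, map(set, per_shard_missing))  # {'layer1.weight'}
remain_keys = remain_keys.union(set(missing_file_keys))
print(remain_keys)  # {'layer1.weight', 'layer4.weight'}: raised if strict, logged otherwise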
@@ -5,9 +5,12 @@ import io
 import pickle
 import re
 from typing import Any, List, Optional, Union
+from collections import namedtuple
 
 import torch
 import torch.distributed as dist
+from dataclasses import dataclass
+from enum import Enum
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
@@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     return unpickle
 
 
+def check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return (
+        c10d.is_nccl_available() and
+        pg.name() == c10d.Backend.NCCL
+    )
+
+
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
 ):
@@ -65,7 +83,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return
 
-    is_nccl_backend = c10d._check_for_nccl_backend(group)
+    is_nccl_backend = check_for_nccl_backend(group)
    current_device = None
 
     if device is not None:
@@ -113,7 +131,7 @@ def _broadcast_object_list(
 
     if my_rank != src:
         for i, obj_size in enumerate(object_sizes_tensor):
-            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = object_tensor[offset: offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
@@ -131,6 +149,258 @@ def _broadcast_object_list(
         object_list[i] = unpickle_object
 
 
+def check_device(group):
+    is_nccl_backend = check_for_nccl_backend(group)
+    current_device = None
+
+    current_device = torch.device("cpu")
+    if is_nccl_backend:
+        current_device = torch.device("cuda", torch.cuda.current_device())
+    return current_device, is_nccl_backend
+
+
+TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+
+
+class P2PDataType(Enum):
+    serialization = 0
+    tensor = 1
+    list = 2
+    dict = 3
+
+
+@dataclass
+class P2PMetadata:
+    data_type: P2PDataType
+    content: Union[List[TensorMetadata], TensorMetadata, Any]
+
+
+def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
+    if isinstance(obj, torch.Tensor):
+        obj = obj.contiguous()
+        op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
+        ops_queue.append(op_to_add)
+    else:
+        for tensor_to_comm in obj:
+            tensor_to_comm = tensor_to_comm.contiguous()
+            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
+            ops_queue.append(op_to_add)
+
+
+def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
+    if p2p_metadata.data_type == P2PDataType.tensor:
+        metadata = p2p_metadata.content
+        tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+        return tensor_recv
+    elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
+        buffer_recv = []
+        for metadata in p2p_metadata.content:
+            tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+            buffer_recv.append(tensor_recv)
+        return buffer_recv
+    else:
+        raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
+
+
+def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
+    buffer_recv = None
+    if recv_tensor_metadata is not None:
+        buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+
+    ops = []
+
+    if send_dst is not None:
+        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+
+    if recv_src is not None:
+        assert buffer_recv is not None
+        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+
+        torch.cuda.synchronize()
+
+        # Remove synchronization according to Pytorch's documentation
+        # However, the Megatron-LM does synchronization here
+        # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
+        # In case there is potential error, uncomment the following `torch.cuda.synchronize()`
+        # torch.cuda.synchronize()
+
+    return buffer_recv
+
+
+def _send_recv_serialization_object(
+        object: Any,
+        send_dst: Optional[int], recv_src: Optional[int],
+        send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup],
+        current_device,
+        is_nccl_backend):
+    ops = []
+
+    send_object_tensor = None
+    if object is not None and send_dst is not None:
+        if Version(torch.__version__) >= Version("1.13.0"):
+            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
+        else:
+            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
+
+        if is_nccl_backend:
+            send_object_size_tensor = send_object_size_tensor.to(current_device)
+            send_object_tensor = send_object_tensor.to(current_device)
+
+        filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+
+    recv_object_size_tensor = None
+    if recv_src is not None:
+        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
+        if is_nccl_backend:
+            recv_object_size_tensor = recv_object_size_tensor.to(current_device)
+        filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+
+        torch.cuda.synchronize()
+
+        # See the comment in `_batch_send_recv_tensor`
+        # torch.cuda.synchronize()
+
+    ops = []
+
+    if send_dst is not None and send_object_tensor is not None:
+        filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+
+    recv_object_tensor = None
+    if recv_src is not None and recv_object_size_tensor is not None:
+        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
+        if is_nccl_backend:
+            recv_object_tensor = recv_object_tensor.to(current_device)
+        filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+
+        torch.cuda.synchronize()
+
+        # See the comment in `_batch_send_recv_tensor`
+        # torch.cuda.synchronize()
+
+    if recv_object_tensor is not None and recv_object_size_tensor is not None:
+        recv_object_tensor = recv_object_tensor.type(torch.uint8)
+        if recv_object_tensor.device != torch.device("cpu"):
+            recv_object_tensor = recv_object_tensor.cpu()
+
+        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
+
+        if (
+            isinstance(unpickle_object, torch.Tensor)
+            and unpickle_object.device.index != torch.cuda.current_device()
+        ):
+            unpickle_object = unpickle_object.cuda()
+
+        return unpickle_object
+
+
+def _check_if_fast_send_available(object):
+    if type(object) is torch.Tensor:
+        return True
+    elif type(object) is list:
+        is_list_of_tensor = all([type(v) is torch.Tensor for v in object])
+        return is_list_of_tensor
+    elif type(object) is dict:
+        is_dict_of_tensor = all([type(k) is str and type(v) is torch.Tensor for k, v in object.items()])
+        return is_dict_of_tensor
+    return False
+
+
+def _communicate(
+    object,
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup] = None,
+    recv_group: Optional[ProcessGroup] = None,
+) -> Any:
+    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
+        c10d._warn_not_in_group("_communicate")
+        return
+
+    current_send_device, is_send_nccl_backend = check_device(send_group)
+    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+
+    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
+
+    assert current_send_device == current_recv_device
+    current_device = current_send_device
+
+    assert (send_dst is not None) or (recv_src is not None)
+
+    can_fast_send = False
+    send_metadata = None
+    if send_dst is not None:
+        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
+        if not can_fast_send:
+            send_metadata = P2PMetadata(P2PDataType.serialization, object)
+        else:
+            if type(object) is torch.Tensor:
+                data_type = P2PDataType.tensor
+                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
+            elif type(object) is list:
+                data_type = P2PDataType.list
+                content = []
+                for v in object:
+                    content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
+            elif type(object) is dict:
+                data_type = P2PDataType.dict
+                content = []
+                for k, v in object.items():
+                    content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
+            else:
+                raise ValueError('Cannot send object of type {}'.format(type(object)))
+            send_metadata = P2PMetadata(data_type, content)
+
+    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
+    if recv_metadata is not None:
+        assert type(recv_metadata) is P2PMetadata
+        if recv_metadata.data_type == P2PDataType.serialization:
+            return recv_metadata.content
+    if not can_fast_send and send_dst is not None:
+        return
+
+    send_tensor_list = None
+    if type(object) is torch.Tensor:
+        send_tensor_list = object
+    elif type(object) is list:
+        send_tensor_list = object
+    elif type(object) is dict:
+        send_tensor_list = list(object.values())
+
+    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
+
+    if recv_metadata is not None:
+        assert recv_buffer is not None
+        if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
+            return recv_buffer
+        elif recv_metadata.data_type == P2PDataType.dict:
+            return {
+                k: v
+                for k, v in zip(
+                    [m.key for m in recv_metadata.content],
+                    recv_buffer,
+                )
+            }
+        else:
+            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
+
+
 def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     """send anything to dst rank
 
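All of the helpers above funnel into torch.distributed.batch_isend_irecv. Below is a minimal sketch of that underlying pattern only, not ColossalAI code: it assumes an already initialized NCCL process group and a peer that sends back a tensor of the same shape and dtype, so the metadata handshake performed by _send_recv_serialization_object is omitted.

import torch
import torch.distributed as dist


def toy_batched_exchange(send_tensor: torch.Tensor, peer: int) -> torch.Tensor:
    # Post one isend and one irecv in a single batch so both directions
    # overlap, mirroring what _batch_send_recv_tensor does with its ops list.
    recv_tensor = torch.empty_like(send_tensor)
    ops = [
        dist.P2POp(dist.isend, send_tensor.contiguous(), peer),
        dist.P2POp(dist.irecv, recv_tensor, peer),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_tensor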
@@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     Returns:
         None
     """
-    # then broadcast safely
-    _broadcast_object_list([object], src, group)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group)
 
 
 def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
@@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     Returns:
         Any: Object received from src.
     """
-    object_list = [None]
-    _broadcast_object_list(object_list, src, group)
-
-    return object_list[0]
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
 
 
 def _p2p_comm(
@@ -302,6 +568,64 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
 
+    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the gradient tensor back from the next stage in pipeline.
+
+        Args:
+            input_object (Any): Object to be sent.
+            next_rank (int, optional): The rank that both receives and sends the tensor.
+        """
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+        return _communicate(
+            input_object, next_rank, next_rank,
+            send_group=group, recv_group=group,
+        )
+
+    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+        """Sends the gradient tensor to the previous stage and copies the input tensor back from the previous stage in pipeline.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank that both receives and sends the tensor.
+        """
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+        return _communicate(
+            input_object, prev_rank, prev_rank,
+            send_group=group, recv_group=group,
+        )
+
+    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the sender of the tensor.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+        return _communicate(
+            input_object,
+            send_dst=next_rank,
+            recv_src=prev_rank,
+            send_group=send_group,
+            recv_group=recv_group,
+        )
+
     def p2p_communicate(
         self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
     ) -> None:
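For intuition on the fast path these methods take through _communicate, here is a local, single-process sketch of the two-phase exchange; the payload names and shapes are invented. Only the small metadata message is pickled, while the tensors themselves travel as raw device buffers.

import torch
from collections import namedtuple

TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])

# A dict of tensors is exactly the kind of payload _check_if_fast_send_available accepts.
payload = {
    "hidden_states": torch.randn(2, 16),
    "attention_mask": torch.ones(2, 16),
}

# Phase 1: describe the payload, so the receiver can pre-allocate buffers.
metadata = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in payload.items()]

# Phase 2: the receiver builds empty buffers from the metadata (as create_recv_buffer
# does) and fills them from the batched irecv; keys are restored by zipping.
buffers = [torch.empty(m.shape, dtype=m.dtype, requires_grad=m.requires_grad) for m in metadata]
restored = {m.key: b for m, b in zip(metadata, buffers)}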
@@ -127,6 +127,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             self.comm.send_forward(output_object, next_rank)
 
+    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the gradient tensor back from the next stage in pipeline.
+        For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_last_stage():
+            return self.comm.send_forward_recv_backward(output_object, next_rank)
+
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For 1F1B.
@@ -138,6 +149,33 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank)
 
+    def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
+        """Sends the gradient tensor to the previous stage and copies the input tensor back from the previous stage in pipeline.
+        For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_first_stage():
+            return self.comm.send_backward_recv_forward(output_object, prev_rank)
+
+    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
+        For 1F1B.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the previous stage, which sends the tensor.
+            next_rank (int, optional): The rank of the next stage, which receives the tensor.
+        """
+        if self.stage_manager.is_first_stage():
+            return self.comm.send_forward(input_object, next_rank)
+        elif self.stage_manager.is_last_stage():
+            return self.comm.recv_forward(prev_rank)
+        else:
+            return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
+
     def forward_step(
         self,
         model: Module,
@@ -291,7 +329,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
             if not last_iteration:
                 input_obj = self.recv_forward()
-
             else:
                 # TODO adjust here
                 self.send_forward(output_obj)
@@ -413,6 +413,7 @@ def get_llama_flash_attention_forward():
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
         assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
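The one-line signature change above makes the patched forward tolerant of extra keyword arguments. A minimal illustration of the failure mode it avoids; the padding_mask name here is an assumption about what newer transformers releases pass, it does not appear in this diff:

def forward_old(hidden_states, output_attentions=False, use_cache=False):
    return hidden_states

def forward_new(hidden_states, output_attentions=False, use_cache=False, **kwargs):
    # Unexpected keyword arguments are simply collected and ignored.
    return hidden_states

forward_new("h", padding_mask=None)  # fine
try:
    forward_old("h", padding_mask=None)
except TypeError as err:
    print(err)  # got an unexpected keyword argument 'padding_mask'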
@@ -183,7 +183,11 @@ def main():
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
     performance_evaluator = PerformanceEvaluator(
-        model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
+        model_numel,
+        model.config.num_hidden_layers,
+        model.config.hidden_size,
+        model.config.vocab_size,
+        args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
     )
 
     optimizer = HybridAdam(model.parameters())
@@ -58,6 +58,9 @@ class PerformanceEvaluator:
     def __init__(
         self,
         model_numel: int,
+        num_layers: int,
+        hidden_size: int,
+        vocab_size: int,
         enable_grad_checkpoint: bool = False,
         ignore_steps: int = 0,
         dp_world_size: Optional[int] = None,
@@ -65,12 +68,16 @@ class PerformanceEvaluator:
         self.model_numel = model_numel
         self.enable_grad_checkpoint = enable_grad_checkpoint
         self.ignore_steps = ignore_steps
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
 
         self.coordinator = DistCoordinator()
         self.dp_world_size = dp_world_size or self.coordinator.world_size
         self.disable: bool = False
         self.timer = Timer()
         self.num_samples: int = 0
+        self.flop_megatron = 0
         self.flop: int = 0
 
     def on_step_start(self, step: int) -> None:
@@ -89,17 +96,20 @@ class PerformanceEvaluator:
         batch_size, seq_len = input_ids.shape
 
         self.num_samples += batch_size
+        checkpoint_activations_factor = (3 + int(self.enable_grad_checkpoint))
+        self.flop_megatron += (24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)) * (1. + (seq_len / (6. * self.hidden_size)) + (self.vocab_size / (16. * self.num_layers * self.hidden_size)))
         self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
 
     def on_fit_end(self) -> None:
         avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
         avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
         mp_world_size = self.coordinator.world_size // self.dp_world_size
+        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
         avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
         self.coordinator.print_on_master(
-            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
             f"avg_throughput: {avg_throughput}"
         )
         self.coordinator.print_on_master(
-            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
         )
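For reference, the Megatron-style estimate added above can be computed stand-alone. The sketch below simply re-expresses the formula from on_step_end; the LLaMA-7B-like numbers in the usage line are illustrative, not measured results:

def megatron_flop_per_iteration(batch_size: int, seq_len: int, num_layers: int,
                                hidden_size: int, vocab_size: int,
                                enable_grad_checkpoint: bool = False) -> float:
    # 3x covers forward plus backward; 4x when activations are recomputed under
    # gradient checkpointing, matching checkpoint_activations_factor above.
    factor = 3 + int(enable_grad_checkpoint)
    return (24 * factor * batch_size * seq_len * num_layers * hidden_size**2) * (
        1.0 + seq_len / (6.0 * hidden_size) + vocab_size / (16.0 * num_layers * hidden_size)
    )

# Illustrative settings: batch 4, sequence 2048, 32 layers, hidden 4096, vocab 32000.
print(f"{megatron_flop_per_iteration(4, 2048, 32, 4096, 32000) / 1e12:.1f} TFLOP per iteration")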