Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 02:26:51 +00:00
support fp8 communication in pipeline parallelism
@@ -104,4 +104,71 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
    for i in range(world_size):
        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
    tensor_out = torch.cat(tensor_list, dim=0)
    tensor.data = tensor_out.view(input_shape).to(input_type)
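
# Illustrative usage sketch, not part of this commit: how all_reduce_fp8 above is intended to be
# called. Assumes the script is launched with torchrun (one GPU per rank) so torch.distributed can
# initialize from the environment; the tensor name t is hypothetical.
import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
t = torch.randn(1024, device="cuda")
all_reduce_fp8(t, fp8_format="e4m3")  # in place: t now holds the all-reduced result, sent over the wire in FP8
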
def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of the inp object to FP8 format before p2p communication in the pipeline.
    The activation tensor is indexed by 'hidden_states' in the inp dict.
    After FP8 casting, the result is stored as a float16 or bfloat16 view of the FP8 data, so its size is halved.
    Metadata such as fp8_scale is saved into the inp dict for communication.
    """
    if inp is None:
        return
    # In pipeline parallelism, when inp is a torch.Tensor it contains only one element, so the cast can be skipped.
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    inp_tensor = inp["hidden_states"]

    min_val, max_val = inp_tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())

    finfo = torch.finfo(torch.float8_e4m3fn)
    if amax > finfo.max:
        fp8_type = torch.float8_e5m2
        fp8_view_type = torch.float16
    else:
        fp8_type = torch.float8_e4m3fn
        fp8_view_type = torch.bfloat16

    finfo = torch.finfo(fp8_type)
    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
    q_tensor = inp_tensor.data.float() * scale
    # TODO: currently fp8_view_type (float16 vs. bfloat16) is used to indicate which FP8 format was used.
    # This is a temporary workaround due to 'Only support tensor for fast send'.
    # inp_tensor needs to stay a float dtype to avoid errors during gradient placement.
    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)

    inp["fp8_scale"] = scale.float().reciprocal()
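
# Illustrative sketch, not part of this commit: the dynamic-scaling scheme used by
# cast_to_fp8_pipeline above, shown on a standalone tensor. Assumes torch >= 2.1 for the
# float8 dtypes; the names x, packed and unpacked are hypothetical.
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for inp["hidden_states"]
amax = x.abs().max()
if amax > torch.finfo(torch.float8_e4m3fn).max:
    # values would overflow e4m3, so fall back to the wider-range e5m2 format
    fp8_type, fp8_view_type = torch.float8_e5m2, torch.float16
else:
    fp8_type, fp8_view_type = torch.float8_e4m3fn, torch.bfloat16
scale = torch.finfo(fp8_type).max / amax.float()  # stretch the data to fill the FP8 range
packed = (x.float() * scale).to(fp8_type).view(fp8_view_type)  # FP8 bits viewed as a 16-bit dtype
unpacked = packed.view(fp8_type).to(torch.float32) / scale  # inverse transform, as in the cast back
print(x.element_size() * x.numel(), packed.element_size() * packed.numel())  # payload is half the size of x
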
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
    """
    Cast the FP8-encoded hidden_states tensor back to the original dtype after p2p communication in the pipeline.
    del_metadata=False is useful when this function is called before p2p communication.
    """
    if inp is None:
        return
    if type(inp) == torch.Tensor:
        return

    assert "hidden_states" in inp, "required by pipeline parallelism."
    inp_tensor = inp["hidden_states"]
    scale = inp["fp8_scale"]

    fp8_view_type = inp_tensor.dtype
    if fp8_view_type == torch.float16:
        fp8_type = torch.float8_e5m2
    elif fp8_view_type == torch.bfloat16:
        fp8_type = torch.float8_e4m3fn
    else:
        raise TypeError("Only float16 and bfloat16 are implemented.")

    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale

    if del_metadata:
        del inp["fp8_scale"]