support fp8 communication in pipeline parallelism

This commit is contained in:
BurkeHulk
2024-07-12 15:25:25 +08:00
parent 1e1959467e
commit e88190184a
4 changed files with 126 additions and 1 deletion


@@ -104,4 +104,71 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
    for i in range(world_size):
        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
    tensor_out = torch.cat(tensor_list, dim=0)
    tensor.data = tensor_out.view(input_shape).to(input_type)


def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of the inp object to fp8 format before p2p communication in pipeline.
    The activations tensor is indexed by 'hidden_states' in the inp dict.
    After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but its size is halved.
    Metadata such as fp8_scale is saved into the inp dict for communication.
    """
    if inp is None:
        return
    # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
    if type(inp) == torch.Tensor:
        return

    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
    inp_tensor = inp["hidden_states"]

    min_val, max_val = inp_tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())

    finfo = torch.finfo(torch.float8_e4m3fn)
    if amax > finfo.max:
        fp8_type = torch.float8_e5m2
        fp8_view_type = torch.float16
    else:
        fp8_type = torch.float8_e4m3fn
        fp8_view_type = torch.bfloat16

    finfo = torch.finfo(fp8_type)
    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
    q_tensor = inp_tensor.data.float() * scale
    # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used.
    # This is a temporary workaround due to 'Only support tensor for fast send'.
    # inp_tensor needs to be a float datatype to avoid error during gradient placement.
    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)

    inp["fp8_scale"] = scale.float().reciprocal()


def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
    """
    Cast the FP8-encoded hidden_states tensor back to the original dtype after p2p communication in pipeline.
    del_metadata = False is useful when this function is called before p2p communication.
    """
    if inp is None:
        return
    if type(inp) == torch.Tensor:
        return

    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
    inp_tensor = inp["hidden_states"]
    scale = inp["fp8_scale"]

    # The view dtype encodes which fp8 format the sender used (see cast_to_fp8_pipeline).
    fp8_view_type = inp_tensor.dtype
    if fp8_view_type == torch.float16:
        fp8_type = torch.float8_e5m2
    elif fp8_view_type == torch.bfloat16:
        fp8_type = torch.float8_e4m3fn
    else:
        raise TypeError("Only float16, bfloat16 are implemented.")

    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale

    if del_metadata:
        del inp["fp8_scale"]
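
For reference, below is a minimal round-trip sketch of how these helpers behave around a pipeline p2p exchange. It is not part of this commit: the import path is an assumption (the helpers live in the fp8 utility module this diff touches), the actual send/recv is elided, and the tolerances only illustrate typical fp8 quantization error.

import torch

# Assumed import path for the helpers added in this diff.
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

# A bf16 activation dict shaped like what one pipeline stage passes to the next.
inp = {"hidden_states": torch.randn(2, 8, 16, dtype=torch.bfloat16, device="cuda")}
ref = inp["hidden_states"].clone()

cast_to_fp8_pipeline(inp)
# The payload is now fp8 data bit-cast to fp16/bf16: half as many viewed elements,
# plus a scalar dequantization scale stored alongside it.
assert inp["hidden_states"].numel() == ref.numel() // 2
assert "fp8_scale" in inp

# ... p2p send/recv of inp["hidden_states"] and inp["fp8_scale"] would happen here ...

cast_from_fp8_pipeline(inp)
# Values come back as float16, equal to the original up to fp8 quantization error.
torch.testing.assert_close(inp["hidden_states"].float(), ref.float(), rtol=0.1, atol=0.1)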