[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any
 
 import torch
 import torch.distributed as dist
@@ -107,7 +107,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     tensor.data = tensor_out.view(input_shape).to(input_type)
 
 
-
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
     Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
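For context, the hunk above ends all_reduce_fp8, which compresses the collective by quantizing to fp8 before communication and restoring shape and dtype afterwards. Below is a minimal sketch of that pattern, assuming the torch.float8 dtypes and an all_gather-based exchange; the helper name and exchange strategy are illustrative, not ColossalAI's actual code path.

import torch
import torch.distributed as dist

def all_reduce_fp8_sketch(tensor: torch.Tensor, fp8_format: str = "e4m3") -> None:
    input_type = tensor.dtype
    input_shape = tensor.shape
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max

    # Per-rank scale so local values fit the fp8 dynamic range
    # (clamped to avoid division by zero for an all-zero tensor).
    amax = tensor.abs().amax().float().clamp(min=torch.finfo(torch.float32).tiny)
    scale = (fp8_max / amax).reshape(1)

    # NCCL cannot reduce fp8 directly, so ship raw bytes (uint8) plus the scale.
    payload = (tensor.float() * scale).to(fp8_type).view(torch.uint8).flatten()

    world_size = dist.get_world_size()
    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload)
    dist.all_gather(scales, scale)

    # Dequantize each rank's contribution with its own scale, then sum locally.
    tensor_out = sum(p.view(fp8_type).float() / s for p, s in zip(payloads, scales))
    # Same final step as in the hunk: restore shape and original dtype in place.
    tensor.data = tensor_out.view(input_shape).to(input_type)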
@@ -121,7 +120,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     if type(inp) == torch.Tensor:
         return
 
-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]
 
     min_val, max_val = inp_tensor.aminmax()
@@ -137,7 +136,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
 
     finfo = torch.finfo(fp8_type)
     scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
-    q_tensor = (inp_tensor.data.float() * scale)
+    q_tensor = inp_tensor.data.float() * scale
     # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
     # inp_tensor needs to be a float datatype to avoid error during gradient placement.
     inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
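The hunk above is the per-tensor quantize step: scale = finfo.max / amax, multiply, cast to fp8, then reinterpret the bytes as fp8_view_type so the p2p layer can ship them as an ordinary tensor. A self-contained round-trip sketch, assuming e4m3 with a float16 view dtype; variable names mirror the diff, but this is not the full ColossalAI code path.

import torch

fp8_type = torch.float8_e4m3fn
fp8_view_type = torch.float16  # fp8 bytes travel reinterpreted as fp16 ("fast send" workaround)

x = torch.randn(4, 8)  # stands in for inp["hidden_states"]

# Sender: per-tensor scale so |x| fits the fp8 dynamic range.
amax = x.abs().amax()
finfo = torch.finfo(fp8_type)
scale = torch.tensor(1.0) if amax == 0.0 else finfo.max / amax.float()

q_tensor = x.data.float() * scale
payload = q_tensor.to(fp8_type).view(fp8_view_type)  # what p2p actually sends
fp8_scale = scale.float().reciprocal()               # shipped alongside, as in the hunk

# Receiver: reinterpret as fp8, upcast, and undo the scale.
recovered = payload.view(fp8_type).to(torch.float16) * fp8_scale
print((recovered - x).abs().max())  # small quantization error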
@@ -145,7 +144,6 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     inp["fp8_scale"] = scale.float().reciprocal()
 
 
-
 def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     """
     Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
@@ -156,7 +154,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     if type(inp) == torch.Tensor:
         return
 
-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]
     scale = inp["fp8_scale"]
 
@@ -171,4 +169,4 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
 
     if del_metadata:
-        del inp['fp8_scale']
+        del inp["fp8_scale"]
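Taken together, cast_to_fp8_pipeline and cast_from_fp8_pipeline bracket the p2p exchange: the sender quantizes hidden_states in place and attaches fp8_scale; the receiver reverses both and, with del_metadata=True, drops the metadata. A hedged usage sketch with plain torch.distributed send/recv follows; the wrapper names are hypothetical, ColossalAI's batched p2p differs, and pre-allocated receive buffers are assumed.

import torch
import torch.distributed as dist

# cast_to_fp8_pipeline / cast_from_fp8_pipeline are the functions from this diff.

def send_hidden_states(inp: dict, dst: int) -> None:
    # Quantizes inp["hidden_states"] in place and attaches inp["fp8_scale"].
    cast_to_fp8_pipeline(inp)
    dist.send(inp["hidden_states"], dst)
    dist.send(inp["fp8_scale"], dst)

def recv_hidden_states(inp: dict, src: int) -> None:
    # inp must hold pre-allocated buffers matching the shapes/dtypes sent above.
    dist.recv(inp["hidden_states"], src)
    dist.recv(inp["fp8_scale"], src)
    # Dequantizes in place and deletes inp["fp8_scale"] (del_metadata defaults to True).
    cast_from_fp8_pipeline(inp)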