mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
fix typo
This commit is contained in:
@@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
|
||||
Returns:
|
||||
Tuples: A tuple (fp8_tensor, scale)
|
||||
"""
|
||||
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
|
||||
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
return inp
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
|
Reference in New Issue
Block a user