[Feature] llama shardformer fp8 support (#5938)

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

Guangyao Zhang authored 2024-08-05 10:05:47 +08:00, committed by GitHub
parent c297e21bea
commit 53cb9606bd
11 changed files with 453 additions and 98 deletions
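
The tests added below compare each fp8 collective in colossalai.quantization.fp8 (all_to_all_fp8, all_to_all_single_fp8, all_reduce_fp8, gather_fp8, reduce_scatter_fp8) against its plain torch.distributed counterpart under loose tolerances (rtol=0.1, atol=0.1) that absorb the fp8 round-trip error. As a rough mental model only (a minimal sketch in plain PyTorch with hypothetical helper names quantize_fp8 / dequantize_fp8, not the library's actual implementation; it assumes a PyTorch build with float8 dtypes), such a collective scales the payload into the fp8 range, casts it to torch.float8_e4m3fn or torch.float8_e5m2, communicates the fp8 bytes together with the scale, and dequantizes on the receiving side:

import torch


def quantize_fp8(x: torch.Tensor, fp8_format: str = "e4m3"):
    """Per-tensor scaled cast to fp8 (illustrative sketch only)."""
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    # Scale so the largest magnitude lands near the top of the fp8 range.
    scale = x.abs().max().to(torch.float32).clamp(min=1e-12) / fp8_max
    x_fp8 = (x.to(torch.float32) / scale).to(fp8_dtype)
    return x_fp8, scale


def dequantize_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Invert the scaled cast; the residual error is why the tests use rtol=0.1, atol=0.1."""
    return (x_fp8.to(torch.float32) * scale).to(dtype)


# Round trip stays well inside the tolerances used by the tests below.
x = torch.rand(16, 8, dtype=torch.bfloat16)
x_fp8, scale = quantize_fp8(x, "e4m3")
torch.testing.assert_close(dequantize_fp8(x_fp8, scale, x.dtype), x, rtol=0.1, atol=0.1)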

View File

@@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
    world_size = dist.get_world_size()
    input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim))
    input_tensor_list = [x.contiguous() for x in input_tensor_list]
    output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
    output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
    # The fp8 all-to-all should match the plain collective within fp8 round-trip error.
    all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
    assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_to_all():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all()

View File

@@ -0,0 +1,37 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(4,), (8, 7), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output = torch.empty_like(x)
    output_fp8 = torch.empty_like(x)
    # all_to_all_single_fp8 should agree with dist.all_to_all_single within fp8 error.
    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all_single(output, x, group=_get_default_group())
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_to_all_single():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all_single()

View File

@@ -32,9 +32,9 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
-def test_all_gather():
+def test_all_gather_flat():
     spawn(run_dist, 4)


 if __name__ == "__main__":
-    test_all_gather()
+    test_all_gather_flat()

View File

@@ -0,0 +1,48 @@
import torch
import torch.distributed as dist
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [
        (3, 7),
        (4, 7),
        (7, 4),
        (8, 9),
        (3,),
        (7,),
        (8,),
    ],
)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    x_fp8 = x.clone()
    # Check both supported reduce ops: the default SUM and ReduceOp.AVG.
    dist.all_reduce(x)
    all_reduce_fp8(x_fp8, fp8_format=fp8_format)
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)

    dist.all_reduce(x, op=dist.ReduceOp.AVG)
    all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_reduce():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_reduce()
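
The SUM and AVG branches in check_4gpu above reflect the "fix reduce op" item in the commit message: all_reduce_fp8 accepts a ReduceOp argument. One simple way to picture the AVG case on top of a sum-style fp8 exchange (a hedged sketch building on the hypothetical quantize_fp8 / dequantize_fp8 helpers shown earlier, not necessarily how colossalai.quantization.fp8 implements it) is to gather every rank's fp8 payload, reduce after dequantization, and divide by the world size:

import torch
import torch.distributed as dist


def naive_all_reduce_fp8(x: torch.Tensor, op=dist.ReduceOp.SUM, fp8_format: str = "e4m3") -> None:
    """Illustrative fp8 all-reduce: quantize, all-gather, dequantize, reduce locally.

    NCCL cannot reduce fp8 operands directly, so the arithmetic happens after dequantization.
    """
    world_size = dist.get_world_size()
    x_fp8, scale = quantize_fp8(x, fp8_format)  # hypothetical helper from the earlier sketch
    # Ship each rank's fp8 payload as raw bytes plus its per-tensor scale.
    payloads = [torch.empty_like(x_fp8.view(torch.uint8)) for _ in range(world_size)]
    scales = [torch.empty_like(scale.reshape(1)) for _ in range(world_size)]
    dist.all_gather(payloads, x_fp8.view(torch.uint8))
    dist.all_gather(scales, scale.reshape(1))
    total = sum(dequantize_fp8(p.view(x_fp8.dtype), s, torch.float32) for p, s in zip(payloads, scales))
    if op == dist.ReduceOp.AVG:
        total = total / world_size  # AVG is just SUM scaled by 1 / world_size
    x.copy_(total.to(x.dtype))  # reduce in place, mirroring all_reduce_fp8's in-place signature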

View File

@@ -0,0 +1,48 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [
        (3, 7),
        (2, 1),
        (1, 2),
        (2, 2),
        (4, 2),
        (5,),
        (4,),
        (2,),
    ],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
    world_size = dist.get_world_size()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output_list = [torch.empty_like(x) for _ in range(world_size)]
    output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
    # gather_fp8 should agree with dist.all_gather within fp8 round-trip error.
    gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_gather(output_list, x, group=_get_default_group())
    assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_gather()

View File

@@ -0,0 +1,38 @@
import torch
from torch.distributed import reduce_scatter
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    # chunks=4 matches the world size spawned below.
    input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
    input_list = [t.contiguous() for t in input_list]
    output_origin = torch.empty_like(input_list[0])
    output_fp8 = torch.empty_like(input_list[0])
    reduce_scatter(output_origin, input_list, group=_get_default_group())
    reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format)
    assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_reduce_scatter():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_reduce_scatter()