Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 22:10:37 +00:00
[shardformer] refactored the shardformer layer structure (#4053)
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
+from colossalai.shardformer.layer import cross_entropy_1d
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
@@ -25,7 +25,7 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
     org_loss = F.cross_entropy(org_pred, org_labels)
 
     dist_pred = pred.chunk(world_size, -1)[rank]
-    dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
+    dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
 
     assert torch.allclose(org_loss, dist_loss,
                           atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
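For readers outside the distributed test harness, the equivalence the hunk above asserts can be illustrated without colossalai or NCCL: split the logits along the vocabulary dimension, combine per-shard partial sums, and recover the same loss as F.cross_entropy on the full tensor. The sketch below simulates the shards sequentially in one process; it is not Colossal-AI's cross_entropy_1d implementation, it omits ignore_index handling, and the helper name sharded_cross_entropy is made up for illustration.

import torch
import torch.nn.functional as F


def sharded_cross_entropy(pred, labels, world_size):
    # one vocabulary shard per simulated rank
    shards = pred.chunk(world_size, dim=-1)
    shard_size = shards[0].shape[-1]
    # per-shard partial sums that a real implementation would all-reduce
    sum_exp = sum(torch.exp(s).sum(dim=-1) for s in shards)
    log_sum_exp = torch.log(sum_exp)
    # only the shard that owns a sample's target class contributes its logit
    target_logit = torch.zeros(labels.shape[0])
    for r, s in enumerate(shards):
        owned = (labels >= r * shard_size) & (labels < (r + 1) * shard_size)
        idx = owned.nonzero(as_tuple=True)[0]
        target_logit[idx] = s[idx, labels[idx] - r * shard_size]
    return (log_sum_exp - target_logit).mean()


pred = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
assert torch.allclose(F.cross_entropy(pred, labels),
                      sharded_cross_entropy(pred, labels, world_size=2),
                      atol=1e-5)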
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
-from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.shardformer.layer import Dropout1D
 from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
 
 
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import Embedding1D
+from colossalai.shardformer.layer import Embedding1D
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
+from colossalai.shardformer.layer import VocabParallelEmbedding1D
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
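Taken together, the six hunks above replace imports from internal submodules (layer.layers, layer.dropout, layer.dist_crossentropy) with the flattened colossalai.shardformer.layer namespace. Collecting only the names that actually appear in this diff, the import style used by the updated tests looks like this:

# Flattened import surface introduced by this refactor; the list contains only
# the symbols that appear in the hunks above, not the full public API.
from colossalai.shardformer.layer import (
    Dropout1D,
    Embedding1D,
    Linear1D_Col,
    Linear1D_Row,
    VocabParallelEmbedding1D,
    cross_entropy_1d,
)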
@@ -1,51 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-
-
-def check_dropout(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
-
-    # prepare data
-    input = torch.randn(5, 4).to('cuda')
-    dropout = Dropout1D(p=0.4).to('cuda')
-    output_list = []
-    # compare the dropout pattern in each device
-    for i in range(2):
-        output = dropout(input)
-        output_list.append(output)
-        dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)]
-        torch.distributed.all_gather(dist_output_list, output)
-        for j in range(world_size):
-            for k in range(world_size):
-                if j != k:
-                    mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0)
-                    assert torch.all(
-                        mask
-                    ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}"
-    # compare the dropout pattern in loacl device
-    for i in range(len(output_list)):
-        for j in range(len(output_list)):
-            if i != j:
-                mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0)
-                assert torch.all(
-                    mask
-                ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}"
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_dropout():
-    spawn(check_dropout, 2)
-
-
-if __name__ == '__main__':
-    test_dropout()
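The deleted file above was the old-style Dropout1D test: it asserted that the dropout mask differs between tensor-parallel ranks and between successive calls on the same rank. The single-process sketch below reproduces just the mask comparison, with plain torch.nn.Dropout standing in for Dropout1D and no colossalai.launch or all_gather; the helper name masks_identical is made up for illustration.

import torch
import torch.nn as nn


def masks_identical(a, b):
    # True only if both outputs zeroed out exactly the same positions
    return bool(torch.all(torch.eq(a, 0.0) == torch.eq(b, 0.0)))


dropout = nn.Dropout(p=0.4)
x = torch.randn(5, 4)
out1, out2 = dropout(x), dropout(x)
# With 20 elements at p=0.4 it is overwhelmingly unlikely that two calls draw
# the same mask, which is the property the deleted test asserted.
assert not masks_identical(out1, out2), "dropout pattern repeated across calls"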
@@ -1,78 +0,0 @@
-import pytest
-import torch
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer
-from colossalai.shardformer.shard.shard_config import ShardConfig
-from colossalai.shardformer.shard.slicer import Slicer
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-
-
-def check_slicer(rank, world_size, port, in_feature, out_feature):
-    disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
-    # initialize slicer
-    shardconfig = ShardConfig(rank=rank, world_size=world_size)
-    slicer = Slicer(shardconfig)
-    # initialize test data
-    weight = torch.randn(in_feature, out_feature)
-    bias = torch.randn(out_feature)
-    policy_layer_cls_list = [Layer, Col_Layer, Row_Layer]
-    n_cast_list = [None, 2, 3, 4]
-    # weight and bias
-    for n_cast in n_cast_list:
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast)
-        expected_sliced_weight = weight
-        expected_sliced_bias = bias
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast)
-        if (n_cast is None):
-            expected_sliced_weight = weight.chunk(world_size, dim=0)[rank]
-            expected_sliced_bias = bias.chunk(world_size)[rank]
-        else:
-            chunks = weight.chunk(world_size * n_cast, dim=0)
-            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0)
-            chunks = bias.chunk(world_size * n_cast, dim=0)
-            expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)])
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
-
-        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast)
-        if (n_cast is None):
-            expected_sliced_weight = weight.chunk(world_size, dim=1)[rank]
-            expected_sliced_bias = bias
-        else:
-            chunks = weight.chunk(world_size * n_cast, dim=1)
-            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1)
-            expected_sliced_bias = bias
-        assert torch.equal(
-            sliced_weight, expected_sliced_weight
-        ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-        assert torch.equal(
-            sliced_bias, expected_sliced_bias
-        ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_slicer():
-    args = dict(in_feature=24, out_feature=48)
-    spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature'])
-
-
-if __name__ == '__main__':
-    test_slicer()
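The second deleted file tested the old Slicer / policy-class API, which this refactor removes. Stripped of that API, the expected values it computed reduce to a strided chunk-and-concatenate rule: split the tensor into world_size * n_cast chunks along the sharding dimension and give each rank every world_size-th chunk starting at its own index. Below is a standalone sketch of that rule; the helper name slice_with_cast is made up and is not part of the new layer API.

from typing import Optional

import torch


def slice_with_cast(tensor, rank, world_size, n_cast: Optional[int], dim: int):
    # n_cast=None reproduces the plain one-chunk-per-rank case from the deleted test
    if n_cast is None:
        return tensor.chunk(world_size, dim=dim)[rank]
    chunks = tensor.chunk(world_size * n_cast, dim=dim)
    return torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=dim)


weight = torch.randn(24, 48)
# column-style sharding (dim=0 in the deleted test) for rank 0 of 2 with n_cast=2
col_shard = slice_with_cast(weight, rank=0, world_size=2, n_cast=2, dim=0)
assert col_shard.shape == (12, 48)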