[tensor] refine linear and add gather for laynorm (#893)

* refine linear and add function to ColoTensor

* add gather for layernorm

* polish

* polish
This commit is contained in:
Ziyue Jiang
2022-04-28 10:55:40 +08:00
committed by GitHub
parent 26c49639d8
commit cb182da7c5
7 changed files with 225 additions and 123 deletions

View File

@@ -145,7 +145,7 @@ def run_linear_tp1d_row_test():
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_linear_tp1d_row_test()
#run_linear_tp1d_row_test()
run_linear_tp1d_col_test()
@pytest.mark.dist

View File

@@ -26,6 +26,77 @@ def set_seed(seed):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def run_1d_col_tp():
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
set_seed(1)
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
parallel_action_list_row = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_row = TensorSpec(parallel_action_list_row)
parallel_action_list_col = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_col = TensorSpec(parallel_action_list_col)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
# A naive way to set spec for all weights in Linear
for name, p in named_params_with_colotensor(model):
if not isinstance(p, ColoTensor):
continue
if 'proj1' in name and ('weight' in name or 'bias' in name):
p.set_spec(spec_col)
if 'proj2' in name and 'weight' in name:
p.set_spec(spec_row)
model = model.cuda()
for i, (data, label) in enumerate(train_dataloader):
data = data.to(get_current_device())
label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Bcast rank0 data to all processes
if criterion:
output = model(data)
loss = criterion(output, label)
else:
output = model(data, label)
loss = output
# For reference
if rank == 0:
if criterion:
output_torch = model_torch(data)
loss_torch = criterion(output_torch, label)
else:
output_torch = model_torch(data, label)
loss_torch = output_torch
if rank == 0:
# print(loss.torch_tensor().item())
# print('loss torch', loss_torch.item())
assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
loss.backward()
if rank == 0:
loss_torch.backward()
if i > 5:
break
def run_1d_row_tp():
# A simple net with two stacked nn.Linear