mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 12:12:46 +00:00
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refractor and cleanup using common helper funcs. Tests passed * Update comments * Fix relative import * Fix param fetching bug
This commit is contained in:
parent
5c09d726a6
commit
ec73f1b5e2
@ -384,7 +384,7 @@ class Linear1D_Row(ParallelModule):
|
|||||||
out_features (int): size of each output sample.
|
out_features (int): size of each output sample.
|
||||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
parallel_input (bool): If set to ``True``, it's assumed that the input is already split/copied across each rank, defaults to False.
|
||||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
|
seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
|
||||||
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
|
seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
|
||||||
@ -544,14 +544,14 @@ class Linear1D_Row(ParallelModule):
|
|||||||
if self.parallel_input:
|
if self.parallel_input:
|
||||||
assert (
|
assert (
|
||||||
input_.shape[-1] == self.weight.shape[-1]
|
input_.shape[-1] == self.weight.shape[-1]
|
||||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
|
||||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||||
)
|
)
|
||||||
input_ = input_
|
input_ = input_
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
||||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected feature dim of input {}.".format(
|
||||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||||
)
|
)
|
||||||
input_ = split_forward_gather_backward(
|
input_ = split_forward_gather_backward(
|
||||||
|
@ -13,7 +13,7 @@ _HID_DIM = 128
|
|||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
|
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if identity:
|
if identity:
|
||||||
self.fc0 = nn.Identity()
|
self.fc0 = nn.Identity()
|
||||||
@ -30,7 +30,7 @@ class Net(nn.Module):
|
|||||||
class TPNet(nn.Module):
|
class TPNet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fc0=nn.Linear(_IN_DIM, _IN_DIM),
|
fc0=nn.Identity(),
|
||||||
fc1=nn.Linear(_IN_DIM, _HID_DIM),
|
fc1=nn.Linear(_IN_DIM, _HID_DIM),
|
||||||
fc2=nn.Linear(_HID_DIM, _IN_DIM),
|
fc2=nn.Linear(_HID_DIM, _IN_DIM),
|
||||||
tp_group=None,
|
tp_group=None,
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.shardformer.layer.utils import Randomizer
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
|
from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
|
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
|
||||||
from colossalai.testing import parameterize, spawn
|
from colossalai.testing import parameterize, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_shardformer.test_model._utils import (
|
from tests.test_shardformer.test_model._utils import (
|
||||||
@ -15,6 +18,88 @@ from tests.test_shardformer.test_model._utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def force_assign_grad(p, g_dtype, grad=None):
|
||||||
|
"""Bypass inconsistent grad and param dtype error when assigning grad"""
|
||||||
|
orig_p = p.data
|
||||||
|
p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad.clone().to(g_dtype)
|
||||||
|
p.grad = p.data
|
||||||
|
p.data = orig_p
|
||||||
|
|
||||||
|
|
||||||
|
def setup_param_groups(model: nn.Module) -> list:
|
||||||
|
no_decay = ["bias", "LayerNorm.weight"]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": 0.1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
|
||||||
|
# setup flatten param groups, sharding spec and shape; (For dist Adafactor and CAME)
|
||||||
|
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
|
||||||
|
flatten_optimizer_grouped_parameters = []
|
||||||
|
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
|
||||||
|
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
# flatten_p = copy.deepcopy(p).flatten()
|
||||||
|
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
|
||||||
|
flatten_optimizer_grouped_parameters.append(flatten_p)
|
||||||
|
if is_distributed_tensor(p):
|
||||||
|
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
|
||||||
|
param_shape[id(flatten_p)] = get_layout(p).global_shape
|
||||||
|
else:
|
||||||
|
sharding_spec[id(flatten_p)] = None
|
||||||
|
param_shape[id(flatten_p)] = p.shape
|
||||||
|
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
|
||||||
|
|
||||||
|
|
||||||
|
def set_master_param_to_shard_param(master_param_list) -> dict:
|
||||||
|
master_param_to_shard_param = {id(p): p for p in master_param_list}
|
||||||
|
return master_param_to_shard_param
|
||||||
|
|
||||||
|
|
||||||
|
def set_dist_grad(
|
||||||
|
dist_module: nn.Module,
|
||||||
|
torch_model: nn.Module,
|
||||||
|
g_dtype: torch.dtype,
|
||||||
|
group: dist.ProcessGroup,
|
||||||
|
tp_spec: DimSpec,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set split grads for Tensor Parallel or ZeRO DP.
|
||||||
|
We do not need a separate treatment for ZeRO,
|
||||||
|
as the wrapper takes care of reduce-scattering grads.
|
||||||
|
"""
|
||||||
|
rank = dist.get_rank(group)
|
||||||
|
world_size = dist.get_world_size(group)
|
||||||
|
|
||||||
|
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
|
||||||
|
if torch_p.grad is None:
|
||||||
|
torch_p.grad = torch.zeros_like(torch_p)
|
||||||
|
|
||||||
|
is_distributed = hasattr(p, "dist_layout")
|
||||||
|
if is_distributed:
|
||||||
|
sharding = p.dist_layout.sharding_spec.sharding_sequence
|
||||||
|
split_dim = sharding.index(tp_spec)
|
||||||
|
shape = torch_p.split(world_size, dim=split_dim)[rank].shape
|
||||||
|
|
||||||
|
indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
|
||||||
|
# Generate grads only for the correctly split chunk
|
||||||
|
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
|
||||||
|
|
||||||
|
else:
|
||||||
|
shape = torch_p.shape
|
||||||
|
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
|
||||||
|
|
||||||
|
force_assign_grad(p, g_dtype, grad=torch_p.grad)
|
||||||
|
|
||||||
|
|
||||||
def check_optim_states(org_optim, sharded_optim):
|
def check_optim_states(org_optim, sharded_optim):
|
||||||
for group in org_optim.param_groups:
|
for group in org_optim.param_groups:
|
||||||
for p in group["params"]:
|
for p in group["params"]:
|
||||||
|
@ -8,6 +8,7 @@ from torch.optim import Adam, AdamW
|
|||||||
|
|
||||||
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_optimizer._utils import force_assign_grad, setup_param_groups
|
||||||
|
|
||||||
_ALLOWED_OPTIM_DEVICES = [
|
_ALLOWED_OPTIM_DEVICES = [
|
||||||
(FusedAdam, torch.device("cuda:0")),
|
(FusedAdam, torch.device("cuda:0")),
|
||||||
@ -26,29 +27,11 @@ _ALLOWED_P_G_TYPES = [
|
|||||||
N_STEPS = 3
|
N_STEPS = 3
|
||||||
|
|
||||||
|
|
||||||
def setup_param_groups(bert_model: nn.Module) -> list:
|
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
|
def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
torch_p.grad = torch.rand_like(torch_p)
|
torch_p.grad = torch.rand_like(torch_p)
|
||||||
# avoid inconsistent grad and param dtype error
|
# avoid inconsistent grad and param dtype error
|
||||||
orig_p = p.data
|
force_assign_grad(p, g_dtype, torch_p.grad)
|
||||||
p.data = torch_p.grad.clone().to(g_dtype)
|
|
||||||
p.grad = p.data
|
|
||||||
p.data = orig_p
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
|
@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import copy
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -16,7 +14,6 @@ from colossalai.shardformer.layer.utils import Randomizer
|
|||||||
from colossalai.tensor.d_tensor import (
|
from colossalai.tensor.d_tensor import (
|
||||||
distribute_tensor,
|
distribute_tensor,
|
||||||
get_device_mesh,
|
get_device_mesh,
|
||||||
get_layout,
|
|
||||||
get_sharding_spec,
|
get_sharding_spec,
|
||||||
is_distributed_tensor,
|
is_distributed_tensor,
|
||||||
shard_colwise,
|
shard_colwise,
|
||||||
@ -28,7 +25,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||||||
from colossalai.utils import set_seed
|
from colossalai.utils import set_seed
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_optimizer._utils import check_dist_optim_state, check_dist_param, check_optim_states
|
from tests.test_optimizer._utils import (
|
||||||
|
check_dist_optim_state,
|
||||||
|
check_dist_param,
|
||||||
|
check_optim_states,
|
||||||
|
set_master_param_to_shard_param,
|
||||||
|
setup_param_groups,
|
||||||
|
)
|
||||||
from tests.test_shardformer.test_model._utils import (
|
from tests.test_shardformer.test_model._utils import (
|
||||||
build_model_from_hybrid_plugin,
|
build_model_from_hybrid_plugin,
|
||||||
build_model_from_low_level_zero_plugin,
|
build_model_from_low_level_zero_plugin,
|
||||||
@ -38,10 +41,13 @@ from tests.test_shardformer.test_model._utils import (
|
|||||||
unwrap_model,
|
unwrap_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
HEIGHT = 4
|
IN_DIM = 4
|
||||||
WIDTH = 4
|
HID_DIM = 4
|
||||||
_TP_SPEC = DimSpec([0])
|
_TP_SPEC = DimSpec([0])
|
||||||
|
|
||||||
|
Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
|
||||||
|
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
|
||||||
|
|
||||||
|
|
||||||
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
|
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
|
||||||
rtol = None
|
rtol = None
|
||||||
@ -59,92 +65,11 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
|
|||||||
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
|
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
# setup param groups; (For zero test optim)
|
|
||||||
def setup_param_groups_zero(model: nn.Module) -> list:
|
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
# setup param groups; (For base optim)
|
|
||||||
def setup_param_groups(model: nn.Module) -> list:
|
|
||||||
optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
# setup flatten param groups, sharding spec and shape; (For dist optim)
|
|
||||||
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
|
|
||||||
flatten_optimizer_grouped_parameters = []
|
|
||||||
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
|
|
||||||
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
# flatten_p = copy.deepcopy(p).flatten()
|
|
||||||
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
|
|
||||||
flatten_optimizer_grouped_parameters.append(flatten_p)
|
|
||||||
if is_distributed_tensor(p):
|
|
||||||
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
|
|
||||||
param_shape[id(flatten_p)] = get_layout(p).global_shape
|
|
||||||
else:
|
|
||||||
sharding_spec[id(flatten_p)] = None
|
|
||||||
param_shape[id(flatten_p)] = p.shape
|
|
||||||
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
|
|
||||||
|
|
||||||
|
|
||||||
def set_dist_grad(
|
|
||||||
dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Set split grads for Tensor Parallel or ZeRO DP.
|
|
||||||
We do not need a separate treatment for ZeRO,
|
|
||||||
as the wrapper takes care of reduce-scattering grads.
|
|
||||||
"""
|
|
||||||
rank = dist.get_rank(group)
|
|
||||||
world_size = dist.get_world_size(group)
|
|
||||||
|
|
||||||
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
|
|
||||||
if torch_p.grad is None:
|
|
||||||
torch_p.grad = torch.zeros_like(torch_p)
|
|
||||||
|
|
||||||
is_distributed = hasattr(p, "dist_layout")
|
|
||||||
if is_distributed:
|
|
||||||
sharding = p.dist_layout.sharding_spec.sharding_sequence
|
|
||||||
split_dim = sharding.index(_TP_SPEC)
|
|
||||||
shape = torch_p.split(world_size, dim=split_dim)[rank].shape
|
|
||||||
|
|
||||||
indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
|
|
||||||
# Generate grads only for the correctly split chunk
|
|
||||||
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
|
|
||||||
|
|
||||||
else:
|
|
||||||
shape = torch_p.shape
|
|
||||||
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
|
|
||||||
|
|
||||||
# avoid inconsistent grad and param dtype error
|
|
||||||
orig_p = p.data
|
|
||||||
p.data = torch_p.grad.clone().to(g_dtype)
|
|
||||||
p.grad = p.data
|
|
||||||
p.data = orig_p
|
|
||||||
|
|
||||||
|
|
||||||
def set_master_param_to_shard_param(master_param_list) -> dict:
|
|
||||||
master_param_to_shard_param = {id(p): p for p in master_param_list}
|
|
||||||
return master_param_to_shard_param
|
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MlpModel, self).__init__()
|
super(MlpModel, self).__init__()
|
||||||
self.linear1 = nn.Linear(HEIGHT, WIDTH)
|
self.linear1 = nn.Linear(IN_DIM, HID_DIM)
|
||||||
self.linear2 = nn.Linear(WIDTH, HEIGHT)
|
self.linear2 = nn.Linear(HID_DIM, IN_DIM)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.linear1(x)
|
x = self.linear1(x)
|
||||||
@ -182,7 +107,7 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# ==============================
|
# ==============================
|
||||||
# Base Case
|
# Base Case
|
||||||
# ==============================
|
# ==============================
|
||||||
H, W = HEIGHT, WIDTH
|
H, W = IN_DIM, HID_DIM
|
||||||
model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight
|
model_col = nn.Linear(H, W).to(local_rank) # Col parallel weight
|
||||||
weight, bias = model_col.weight, model_col.bias
|
weight, bias = model_col.weight, model_col.bias
|
||||||
|
|
||||||
@ -284,8 +209,11 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# ==============================
|
# ==============================
|
||||||
# Model Init
|
# Model Init
|
||||||
# ==============================
|
# ==============================
|
||||||
base_model = MlpModel().to(local_rank)
|
# base_model = MlpModel().to(local_rank)
|
||||||
tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
|
# tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
|
||||||
|
base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank)
|
||||||
|
# Must specify dtype; TPNet init seem to run out of set_default_dtype scope
|
||||||
|
tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype)
|
||||||
|
|
||||||
base_param_group = setup_param_groups(base_model)
|
base_param_group = setup_param_groups(base_model)
|
||||||
tp_param_group = setup_param_groups(tp_model)
|
tp_param_group = setup_param_groups(tp_model)
|
||||||
@ -335,7 +263,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# ==============================
|
# ==============================
|
||||||
# Correctness Verify
|
# Correctness Verify
|
||||||
# ==============================
|
# ==============================
|
||||||
x = torch.randn(HEIGHT, WIDTH, device=local_rank)
|
x = torch.randn(IN_DIM, HID_DIM, device=local_rank)
|
||||||
|
|
||||||
out = base_model(x)
|
out = base_model(x)
|
||||||
out_tp = tp_model(x)
|
out_tp = tp_model(x)
|
||||||
@ -353,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
base_optim.zero_grad()
|
base_optim.zero_grad()
|
||||||
dist_optim.zero_grad()
|
dist_optim.zero_grad()
|
||||||
|
|
||||||
for p, tp_p in zip(base_param_group, tp_param_group):
|
base_params = base_model.parameters()
|
||||||
|
tp_params = tp_model.parameters()
|
||||||
|
for p, tp_p in zip(base_params, tp_params):
|
||||||
param_is_distributed = is_distributed_tensor(tp_p)
|
param_is_distributed = is_distributed_tensor(tp_p)
|
||||||
if param_is_distributed:
|
if param_is_distributed:
|
||||||
shard_spec = get_sharding_spec(tp_p)
|
shard_spec = get_sharding_spec(tp_p)
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
import copy
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import nn
|
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
@ -11,17 +8,23 @@ from colossalai.cluster import ProcessGroupMesh
|
|||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.nn.optimizer.came import CAME
|
from colossalai.nn.optimizer.came import CAME
|
||||||
from colossalai.nn.optimizer.distributed_came import DistributedCAME
|
from colossalai.nn.optimizer.distributed_came import DistributedCAME
|
||||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
|
||||||
from colossalai.shardformer.layer._operation import _gather
|
from colossalai.shardformer.layer._operation import _gather
|
||||||
from colossalai.shardformer.layer.utils import Randomizer
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
from colossalai.tensor.d_tensor import get_layout, get_sharding_spec, is_distributed_tensor
|
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
|
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_optimizer._utils import check_dist_grad, check_dist_optim_state, check_dist_param, check_optim_states
|
from tests.test_optimizer._utils import (
|
||||||
|
check_dist_grad,
|
||||||
|
check_dist_optim_state,
|
||||||
|
check_dist_param,
|
||||||
|
check_optim_states,
|
||||||
|
set_master_param_to_shard_param,
|
||||||
|
setup_param_groups,
|
||||||
|
)
|
||||||
from tests.test_shardformer.test_model._utils import (
|
from tests.test_shardformer.test_model._utils import (
|
||||||
build_model_from_hybrid_plugin,
|
build_model_from_hybrid_plugin,
|
||||||
build_model_from_low_level_zero_plugin,
|
build_model_from_low_level_zero_plugin,
|
||||||
@ -30,10 +33,12 @@ from tests.test_shardformer.test_model._utils import (
|
|||||||
unwrap_model,
|
unwrap_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
HEIGHT = 128
|
IN_DIM = 128
|
||||||
WIDTH = 128
|
HID_DIM = 128
|
||||||
_TP_SPEC = DimSpec([0])
|
_TP_SPEC = DimSpec([0])
|
||||||
_SEED = 0
|
_SEED = 0
|
||||||
|
Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
|
||||||
|
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
|
||||||
|
|
||||||
|
|
||||||
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
|
def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torch.dtype = torch.float32):
|
||||||
@ -53,112 +58,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
|
|||||||
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
|
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
# setup param groups; (For zero test optim)
|
|
||||||
def setup_param_groups_zero(model: nn.Module) -> list:
|
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
# setup param groups; (For base optim)
|
|
||||||
def setup_param_groups(model: nn.Module) -> list:
|
|
||||||
optimizer_grouped_parameters = [p for n, p in model.named_parameters()]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
# setup flatten param groups, sharding spec and shape; (For dist optim)
|
|
||||||
def setup_flatten_param_groups_sharding_spec_shape(model: nn.Module) -> dict:
|
|
||||||
flatten_optimizer_grouped_parameters = []
|
|
||||||
sharding_spec = {} # {id(flatten param): get_layout(p).global_shape}
|
|
||||||
param_shape = {} # {id(flatten param): get_sharding_spec(p)}
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
flatten_p = nn.Parameter(p.clone().flatten().requires_grad_(True))
|
|
||||||
flatten_optimizer_grouped_parameters.append(flatten_p)
|
|
||||||
if is_distributed_tensor(p):
|
|
||||||
sharding_spec[id(flatten_p)] = get_sharding_spec(p)
|
|
||||||
param_shape[id(flatten_p)] = get_layout(p).global_shape
|
|
||||||
else:
|
|
||||||
sharding_spec[id(flatten_p)] = None
|
|
||||||
param_shape[id(flatten_p)] = p.shape
|
|
||||||
return flatten_optimizer_grouped_parameters, sharding_spec, param_shape
|
|
||||||
|
|
||||||
|
|
||||||
def set_dist_grad(
|
|
||||||
dist_module: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype, group: dist.ProcessGroup
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Set split grads for Tensor Parallel or ZeRO DP.
|
|
||||||
We do not need a separate treatment for ZeRO,
|
|
||||||
as the wrapper takes care of reduce-scattering grads.
|
|
||||||
"""
|
|
||||||
rank = dist.get_rank(group)
|
|
||||||
world_size = dist.get_world_size(group)
|
|
||||||
|
|
||||||
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
|
|
||||||
if torch_p.grad is None:
|
|
||||||
torch_p.grad = torch.zeros_like(torch_p)
|
|
||||||
|
|
||||||
is_distributed = hasattr(p, "dist_layout")
|
|
||||||
if is_distributed:
|
|
||||||
sharding = p.dist_layout.sharding_spec.sharding_sequence
|
|
||||||
split_dim = sharding.index(_TP_SPEC)
|
|
||||||
shape = torch_p.split(world_size, dim=split_dim)[rank].shape
|
|
||||||
|
|
||||||
indices = torch.arange(shape[split_dim] * rank, shape[split_dim] * (rank + 1))
|
|
||||||
# Generate grads only for the correctly split chunk
|
|
||||||
torch_p.grad.index_add_(split_dim, indices, torch.randn(shape, device=torch_p.device, dtype=g_dtype))
|
|
||||||
|
|
||||||
else:
|
|
||||||
shape = torch_p.shape
|
|
||||||
torch_p.grad += torch.randn(shape, device=torch_p.device, dtype=g_dtype)
|
|
||||||
|
|
||||||
# avoid inconsistent grad and param dtype error
|
|
||||||
orig_p = p.data
|
|
||||||
p.data = torch_p.grad.clone().to(g_dtype)
|
|
||||||
p.grad = p.data
|
|
||||||
p.data = orig_p
|
|
||||||
|
|
||||||
|
|
||||||
def set_master_param_to_shard_param(master_param_list) -> dict:
|
|
||||||
master_param_to_shard_param = {id(p): p for p in master_param_list}
|
|
||||||
return master_param_to_shard_param
|
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MlpModel, self).__init__()
|
|
||||||
self.linear1 = nn.Linear(HEIGHT, WIDTH)
|
|
||||||
self.linear2 = nn.Linear(WIDTH, HEIGHT)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.linear1(x)
|
|
||||||
x = self.linear2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TPModel(nn.Module):
|
|
||||||
def __init__(self, linear1, linear2, tp_group=None):
|
|
||||||
super().__init__()
|
|
||||||
self.linear1 = Linear1D_Col.from_native_module(
|
|
||||||
linear1, process_group=tp_group, gather_output=False, overlap=True
|
|
||||||
)
|
|
||||||
self.linear2 = Linear1D_Row.from_native_module(linear2, process_group=tp_group, parallel_input=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.linear1(x)
|
|
||||||
x = self.linear2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16
|
@parameterize("dtype", [torch.float32]) # torch.float32, torch.float16, torch.bfloat16
|
||||||
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (4, 1), (1, 4)
|
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)]) # (4, 1), (1, 4)
|
||||||
def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
||||||
@ -177,12 +76,13 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# ==============================
|
# ==============================
|
||||||
# Model Init
|
# Model Init
|
||||||
# ==============================
|
# ==============================
|
||||||
base_model = MlpModel().to(local_rank)
|
base_model = Net(in_dim=IN_DIM, hid_dim=HID_DIM, dtype=dtype).to(local_rank)
|
||||||
tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
|
# tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
|
||||||
|
tp_model = TPNet(fc1=base_model.fc1, fc2=base_model.fc2, tp_group=tp_group, dtype=dtype)
|
||||||
|
|
||||||
base_param_group = setup_param_groups(base_model)
|
base_param_group = setup_param_groups(base_model)
|
||||||
tp_param_group = setup_param_groups(tp_model)
|
tp_param_group = setup_param_groups(tp_model)
|
||||||
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
|
# tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Optimizer Init
|
# Optimizer Init
|
||||||
@ -220,7 +120,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# Correctness Verify
|
# Correctness Verify
|
||||||
# ==============================
|
# ==============================
|
||||||
seed_all(1024)
|
seed_all(1024)
|
||||||
x = torch.randn(WIDTH, HEIGHT, device=local_rank)
|
x = torch.randn(HID_DIM, IN_DIM, device=local_rank)
|
||||||
|
|
||||||
out = base_model(x)
|
out = base_model(x)
|
||||||
out_tp = tp_model(x)
|
out_tp = tp_model(x)
|
||||||
@ -238,7 +138,9 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
base_optim.zero_grad()
|
base_optim.zero_grad()
|
||||||
dist_optim.zero_grad()
|
dist_optim.zero_grad()
|
||||||
|
|
||||||
for p, tp_p in zip(base_param_group, tp_param_group):
|
base_params = base_model.parameters()
|
||||||
|
tp_params = tp_model.parameters()
|
||||||
|
for p, tp_p in zip(base_params, tp_params):
|
||||||
param_is_distributed = is_distributed_tensor(tp_p)
|
param_is_distributed = is_distributed_tensor(tp_p)
|
||||||
if param_is_distributed:
|
if param_is_distributed:
|
||||||
shard_spec = get_sharding_spec(tp_p)
|
shard_spec = get_sharding_spec(tp_p)
|
||||||
@ -256,6 +158,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
|
|||||||
# No TP bias
|
# No TP bias
|
||||||
pass
|
pass
|
||||||
correctness_verify(p.data, tp_p.data, dtype)
|
correctness_verify(p.data, tp_p.data, dtype)
|
||||||
|
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
Randomizer.reset_index()
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
@ -17,7 +16,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_optimizer._utils import check_optim_states, run_bert_test
|
from tests.test_optimizer._utils import check_optim_states, run_bert_test, set_dist_grad
|
||||||
|
|
||||||
_ALLOWED_P_G_TYPES = [
|
_ALLOWED_P_G_TYPES = [
|
||||||
(torch.float, torch.float), # pure fp32
|
(torch.float, torch.float), # pure fp32
|
||||||
@ -109,39 +108,6 @@ def force_assign_grad(p, g_dtype, grad=None):
|
|||||||
p.data = orig_p
|
p.data = orig_p
|
||||||
|
|
||||||
|
|
||||||
def set_dist_grad(
|
|
||||||
dist_module: nn.Module,
|
|
||||||
torch_model: nn.Module,
|
|
||||||
g_dtype: torch.dtype,
|
|
||||||
group: dist.ProcessGroup,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Set grads chunks for Tensor Parallel or ZeRO DP.
|
|
||||||
We do not need a separate treatment for ZeRO,
|
|
||||||
as the LowLevelOptimizer takes care of reduce-scattering grads.
|
|
||||||
"""
|
|
||||||
rank = dist.get_rank(group)
|
|
||||||
world_size = dist.get_world_size(group)
|
|
||||||
|
|
||||||
for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
|
|
||||||
if torch_p.grad is None:
|
|
||||||
# avoid inconsistent grad and param dtype error
|
|
||||||
force_assign_grad(torch_p, g_dtype)
|
|
||||||
else:
|
|
||||||
torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)
|
|
||||||
|
|
||||||
if p.grad is None:
|
|
||||||
force_assign_grad(p, g_dtype)
|
|
||||||
|
|
||||||
if is_distributed_tensor(p):
|
|
||||||
split_dim = get_shard_dim_1d(p)
|
|
||||||
# Add grads only to the correctly split chunk
|
|
||||||
force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank].contiguous())
|
|
||||||
# assert_close(p.grad, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
|
|
||||||
else:
|
|
||||||
force_assign_grad(p, g_dtype, torch_p.grad)
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
|
||||||
@parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)])
|
@parameterize("tp_zero_size", [(4, 1), (1, 4), (2, 2)])
|
||||||
def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
|
def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]) -> None:
|
||||||
@ -158,7 +124,7 @@ def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_si
|
|||||||
|
|
||||||
dist.get_rank(tp_group)
|
dist.get_rank(tp_group)
|
||||||
seed_all(_SEED) # Fix model init
|
seed_all(_SEED) # Fix model init
|
||||||
torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=p_dtype).to(rank)
|
torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, dtype=p_dtype).to(rank)
|
||||||
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
|
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
|
||||||
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
|
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
|
||||||
|
|
||||||
@ -222,7 +188,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_
|
|||||||
|
|
||||||
seed_all(_SEED)
|
seed_all(_SEED)
|
||||||
clear_layout_converter() # Ensure correct sharding
|
clear_layout_converter() # Ensure correct sharding
|
||||||
torch_model = Net(_IN_DIM, _HID_DIM, identity=True, dtype=p_dtype).to(rank)
|
torch_model = Net(_IN_DIM, _HID_DIM, dtype=p_dtype).to(rank)
|
||||||
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
|
tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group, dtype=p_dtype).to(rank)
|
||||||
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
|
assert_distributed_close(tp_model, torch_model, rtol=0, atol=0, tp_group=tp_group)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_optimizer._utils import check_optim_states, run_bert_test
|
from tests.test_optimizer._utils import check_optim_states, force_assign_grad, run_bert_test, setup_param_groups
|
||||||
|
|
||||||
_ALLOWED_P_G_TYPES = [
|
_ALLOWED_P_G_TYPES = [
|
||||||
(torch.float, torch.float), # pure fp32
|
(torch.float, torch.float), # pure fp32
|
||||||
@ -49,29 +49,6 @@ def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def setup_param_groups(bert_model: nn.Module) -> list:
|
|
||||||
no_decay = ["bias", "LayerNorm.weight"]
|
|
||||||
optimizer_grouped_parameters = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
|
|
||||||
def force_assign_grad(p, g_dtype, grad=None):
|
|
||||||
"""avoid inconsistent grad and param dtype error"""
|
|
||||||
orig_p = p.data
|
|
||||||
p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad == None else grad
|
|
||||||
p.grad = p.data
|
|
||||||
p.data = orig_p
|
|
||||||
|
|
||||||
|
|
||||||
def set_dist_grad(
|
def set_dist_grad(
|
||||||
dist_module: nn.Module,
|
dist_module: nn.Module,
|
||||||
torch_model: nn.Module,
|
torch_model: nn.Module,
|
||||||
|
Loading…
Reference in New Issue
Block a user