mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 19:55:29 +00:00
[shardformer] added embedding gradient check (#4124)
This commit is contained in:
parent
44a190e6ac
commit
ae035d305d
@ -55,7 +55,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
if ignore:
|
if ignore:
|
||||||
return
|
return
|
||||||
raise AttributeError(f"Object {obj} has no attribute {attr}")
|
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||||
setattr(obj, attrs[-1], value)
|
setattr(obj, attrs[-1], value)
|
||||||
|
|
||||||
|
|
||||||
@ -76,5 +76,5 @@ def getattr_(obj, attr: str, ignore: bool = False):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
if ignore:
|
if ignore:
|
||||||
return None
|
return None
|
||||||
raise AttributeError(f"Object {obj} has no attribute {attr}")
|
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||||
return obj
|
return obj
|
||||||
|
@ -97,7 +97,7 @@ class BertPolicy(Policy):
|
|||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="dropout",
|
suffix="dropout",
|
||||||
target_module=col_nn.DropoutForParallelInput,
|
target_module=col_nn.DropoutForReplicatedInput,
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
|
from .._utils import getattr_, setattr_
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
|
|
||||||
@ -73,7 +75,6 @@ class BloomPolicy(Policy):
|
|||||||
r"""
|
r"""
|
||||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||||
"""
|
"""
|
||||||
# TODO:
|
|
||||||
vocab_size = self.model.config.vocab_size
|
vocab_size = self.model.config.vocab_size
|
||||||
world_size = self.shard_config.tensor_parallel_size
|
world_size = self.shard_config.tensor_parallel_size
|
||||||
if vocab_size % world_size != 0:
|
if vocab_size % world_size != 0:
|
||||||
@ -161,13 +162,12 @@ class BloomPolicy(Policy):
|
|||||||
|
|
||||||
def new_model_class(self):
|
def new_model_class(self):
|
||||||
# do nothing
|
# do nothing
|
||||||
return self.model
|
return None
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
# BertModel
|
|
||||||
class BloomModelPolicy(BloomPolicy):
|
class BloomModelPolicy(BloomPolicy):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -191,6 +191,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
|
||||||
|
for k, v in binding_map.items():
|
||||||
|
param = getattr_(self.model, k)
|
||||||
|
|
||||||
|
if not isinstance(param, nn.Parameter):
|
||||||
|
param = nn.Parameter(param)
|
||||||
|
|
||||||
|
# tie weights
|
||||||
|
setattr_(self.model, k, param)
|
||||||
|
setattr_(self.model, v, param)
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
class BloomForSequenceClassificationPolicy(BloomPolicy):
|
class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||||
|
|
||||||
|
from .._utils import getattr_, setattr_
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -35,7 +36,7 @@ class OPTPolicy(Policy):
|
|||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="embed_tokens",
|
suffix="embed_tokens",
|
||||||
target_module=Embedding1D,
|
target_module=VocabParallelEmbedding1D,
|
||||||
)
|
)
|
||||||
]),
|
]),
|
||||||
OPTDecoderLayer:
|
OPTDecoderLayer:
|
||||||
@ -127,6 +128,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
binding_map = {
|
||||||
|
'model.decoder.embed_tokens': 'lm_head',
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v in binding_map.items():
|
||||||
|
src_mod = getattr_(self.model, k)
|
||||||
|
dst_mod = getattr_(self.model, v)
|
||||||
|
dst_mod.weight = src_mod.weight
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
class OPTForSequenceClassificationPolicy(OPTPolicy):
|
class OPTForSequenceClassificationPolicy(OPTPolicy):
|
||||||
|
|
||||||
|
@ -1,11 +1,20 @@
|
|||||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
|
from colossalai.shardformer.layer import (
|
||||||
|
DropoutForParallelInput,
|
||||||
|
Embedding1D,
|
||||||
|
FusedRMSNorm,
|
||||||
|
Linear1D_Col,
|
||||||
|
Linear1D_Row,
|
||||||
|
VocabParallelEmbedding1D,
|
||||||
|
)
|
||||||
|
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription
|
||||||
|
|
||||||
|
from .._utils import getattr_, setattr_
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||||
|
|
||||||
|
|
||||||
class T5ModelPolicy(Policy):
|
class T5BasePolicy(Policy):
|
||||||
|
|
||||||
def config_sanity_check(self):
|
def config_sanity_check(self):
|
||||||
pass
|
pass
|
||||||
@ -33,7 +42,7 @@ class T5ModelPolicy(Policy):
|
|||||||
T5Stack,
|
T5Stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
base_policy = {
|
||||||
T5Stack:
|
T5Stack:
|
||||||
ModulePolicyDescription(attribute_replacement={},
|
ModulePolicyDescription(attribute_replacement={},
|
||||||
param_replacement=[],
|
param_replacement=[],
|
||||||
@ -41,6 +50,10 @@ class T5ModelPolicy(Policy):
|
|||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="dropout",
|
suffix="dropout",
|
||||||
target_module=DropoutForParallelInput,
|
target_module=DropoutForParallelInput,
|
||||||
|
),
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="embed_tokens",
|
||||||
|
target_module=Embedding1D,
|
||||||
)
|
)
|
||||||
]),
|
]),
|
||||||
T5LayerSelfAttention:
|
T5LayerSelfAttention:
|
||||||
@ -158,30 +171,86 @@ class T5ModelPolicy(Policy):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
|
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||||
|
|
||||||
|
for k, v in binding_map:
|
||||||
|
mod = getattr_(self.model, k)
|
||||||
|
setattr_(self.model, v, mod)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
class T5ModelPolicy(T5BasePolicy):
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import T5Model
|
||||||
|
|
||||||
|
base_policy = super().module_policy()
|
||||||
|
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="shared",
|
||||||
|
target_module=VocabParallelEmbedding1D,
|
||||||
|
)
|
||||||
|
])
|
||||||
|
return base_policy
|
||||||
|
|
||||||
|
|
||||||
|
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
from transformers import T5ForConditionalGeneration
|
from transformers import T5ForConditionalGeneration
|
||||||
|
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
|
||||||
new_item = {
|
param_replacement=[],
|
||||||
T5ForConditionalGeneration:
|
sub_module_replacement=[
|
||||||
ModulePolicyDescription(attribute_replacement={},
|
SubModuleReplacementDescription(
|
||||||
param_replacement=[],
|
suffix="shared",
|
||||||
sub_module_replacement=[
|
target_module=VocabParallelEmbedding1D,
|
||||||
SubModuleReplacementDescription(suffix="lm_head",
|
),
|
||||||
target_module=Linear1D_Col,
|
SubModuleReplacementDescription(
|
||||||
kwargs=dict(gather_output=True))
|
suffix="lm_head",
|
||||||
])
|
target_module=Linear1D_Col,
|
||||||
}
|
kwargs=dict(gather_output=True))
|
||||||
|
])
|
||||||
policy.update(new_item)
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
super().postprocess()
|
||||||
|
|
||||||
class T5EncoderPolicy(T5ModelPolicy):
|
binding_map = {"shared": "lm_head"}
|
||||||
pass
|
|
||||||
|
for k, v in binding_map.items():
|
||||||
|
src_mod = getattr_(self.model, k)
|
||||||
|
dst_mod = getattr_(self.model, v)
|
||||||
|
dst_mod.weight = src_mod.weight
|
||||||
|
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
|
class T5EncoderPolicy(T5BasePolicy):
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import T5EncoderModel
|
||||||
|
|
||||||
|
base_policy = super().module_policy()
|
||||||
|
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
|
||||||
|
param_replacement=[],
|
||||||
|
sub_module_replacement=[
|
||||||
|
SubModuleReplacementDescription(
|
||||||
|
suffix="shared",
|
||||||
|
target_module=VocabParallelEmbedding1D,
|
||||||
|
)
|
||||||
|
])
|
||||||
|
return base_policy
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
binding_map = [
|
||||||
|
["shared", "encoder.embed_tokens"],
|
||||||
|
]
|
||||||
|
|
||||||
|
for k, v in binding_map:
|
||||||
|
mod = getattr_(self.model, k)
|
||||||
|
setattr_(self.model, v, mod)
|
||||||
|
return self.model
|
||||||
|
@ -38,17 +38,6 @@ class ModelSharder(object):
|
|||||||
self._replace_module()
|
self._replace_module()
|
||||||
self._postprocess()
|
self._postprocess()
|
||||||
|
|
||||||
def reshape_embedding(self) -> None:
|
|
||||||
r"""
|
|
||||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
|
||||||
"""
|
|
||||||
vocab_size = self.model_config.vocab_size
|
|
||||||
world_size = self.shard_config.world_size
|
|
||||||
if vocab_size % world_size != 0:
|
|
||||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
|
||||||
self.model.resize_token_embeddings(new_vocab_size)
|
|
||||||
self.model_config = self.model.config
|
|
||||||
|
|
||||||
def _preprocess(self) -> None:
|
def _preprocess(self) -> None:
|
||||||
self.model = self.policy.preprocess()
|
self.model = self.policy.preprocess()
|
||||||
|
|
||||||
|
@ -70,6 +70,8 @@ class ModelZooRegistry(dict):
|
|||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
if keyword in k:
|
if keyword in k:
|
||||||
new_dict[k] = v
|
new_dict[k] = v
|
||||||
|
|
||||||
|
assert len(new_dict) > 0, f'No model found with keyword {keyword}'
|
||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,20 +18,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad equality
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# check grad
|
||||||
|
|
||||||
if org_model.__class__.__name__ == 'BertModel':
|
if org_model.__class__.__name__ == 'BertModel':
|
||||||
org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
|
bert = org_model
|
||||||
shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
|
sharded_bert = sharded_model
|
||||||
else:
|
else:
|
||||||
org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
|
bert = org_model.bert
|
||||||
shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
|
sharded_bert = sharded_model.bert
|
||||||
|
|
||||||
|
# compare self attention grad
|
||||||
|
org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
|
||||||
|
shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
|
||||||
|
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
# compare embedding grad
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
org_grad = bert.embeddings.word_embeddings.weight.grad
|
||||||
|
shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
|
||||||
|
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
assert torch.allclose(org_grad, all_shard_grad,
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad equality
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
if org_model.__class__.__name__ == 'BloomModel':
|
if org_model.__class__.__name__ == 'BloomModel':
|
||||||
org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
|
bloom = org_model
|
||||||
shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
|
sharded_bloom = sharded_model
|
||||||
else:
|
else:
|
||||||
org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
|
bloom = org_model.transformer
|
||||||
shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
|
sharded_bloom = sharded_model.transformer
|
||||||
|
|
||||||
|
# check attention grad
|
||||||
|
org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
|
||||||
|
shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
|
||||||
|
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
# check embedding weights
|
||||||
|
org_grad = bloom.word_embeddings.weight.grad
|
||||||
|
shard_grad = sharded_bloom.word_embeddings.weight.grad
|
||||||
|
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
|
||||||
assert torch.allclose(org_grad, all_shard_grad,
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad equality
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
if org_model.__class__.__name__ == 'GPT2Model':
|
if org_model.__class__.__name__ == 'GPT2Model':
|
||||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
org_model = org_model
|
||||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
sharded_model = sharded_model
|
||||||
else:
|
else:
|
||||||
org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
|
org_model = org_model.transformer
|
||||||
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad
|
sharded_model = sharded_model.transformer
|
||||||
|
|
||||||
|
# check mlp grad
|
||||||
|
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||||
|
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||||
|
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
assert torch.allclose(
|
||||||
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
# check embedding weights
|
||||||
|
org_grad = org_model.wte.weight.grad
|
||||||
|
shard_grad = sharded_model.wte.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
org_grad, all_shard_grad,
|
org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
@ -23,7 +23,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
if hasattr(org_model, 'model'):
|
if hasattr(org_model, 'model'):
|
||||||
llama_model = org_model.model
|
llama_model = org_model.model
|
||||||
shard_llama_model = sharded_model.model
|
shard_llama_model = sharded_model.model
|
||||||
@ -31,14 +34,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
llama_model = org_model
|
llama_model = org_model
|
||||||
shard_llama_model = sharded_model
|
shard_llama_model = sharded_model
|
||||||
|
|
||||||
|
# check attention grad
|
||||||
org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
|
org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
|
||||||
shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
|
shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
# check embedding grad
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
org_grad = llama_model.embed_tokens.weight.grad
|
||||||
|
shard_grad = shard_llama_model.embed_tokens.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||||
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
assert torch.allclose(org_grad, all_shard_grad,
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||||
|
|
||||||
|
@ -28,7 +28,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
if hasattr(org_model, 'model'):
|
if hasattr(org_model, 'model'):
|
||||||
opt_model = org_model.model
|
opt_model = org_model.model
|
||||||
shard_opt_model = sharded_model.model
|
shard_opt_model = sharded_model.model
|
||||||
@ -36,16 +39,23 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
opt_model = org_model
|
opt_model = org_model
|
||||||
shard_opt_model = sharded_model
|
shard_opt_model = sharded_model
|
||||||
|
|
||||||
|
# check attention grad
|
||||||
org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||||
shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||||
|
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
|
||||||
assert torch.allclose(org_grad, all_shard_grad,
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
# check embedding grad
|
||||||
|
org_grad = opt_model.decoder.embed_tokens.weight.grad
|
||||||
|
shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||||
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
|
||||||
def check_OPTModel(rank, world_size, port):
|
def check_OPTModel(rank, world_size, port):
|
||||||
@ -65,3 +75,7 @@ def check_OPTModel(rank, world_size, port):
|
|||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def test_OPTModel():
|
def test_OPTModel():
|
||||||
spawn(check_OPTModel, 4)
|
spawn(check_OPTModel, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_OPTModel()
|
||||||
|
@ -21,19 +21,43 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
|||||||
org_loss.backward()
|
org_loss.backward()
|
||||||
shard_loss.backward()
|
shard_loss.backward()
|
||||||
|
|
||||||
# check grad equality
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
|
||||||
|
# check attention grad
|
||||||
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
|
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
|
||||||
shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
|
shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
|
||||||
|
|
||||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
|
||||||
assert torch.allclose(org_grad, all_shard_grad,
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||||
|
|
||||||
|
# check self attention embed
|
||||||
|
org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
|
||||||
|
shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
# check token embedding grad
|
||||||
|
org_grad = org_model.shared.weight.grad
|
||||||
|
|
||||||
|
# check weights are tied
|
||||||
|
if hasattr(org_model, 'lm_head'):
|
||||||
|
assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
|
||||||
|
assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
|
||||||
|
|
||||||
|
shard_grad = sharded_model.shared.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||||
|
torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||||
|
|
||||||
|
|
||||||
def check_t5(rank, world_size, port):
|
def check_t5(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
@ -44,7 +68,6 @@ def check_t5(rank, world_size, port):
|
|||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
org_model, sharded_model = build_model(model_fn)
|
org_model, sharded_model = build_model(model_fn)
|
||||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@ -56,4 +79,4 @@ def test_t5():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_t5()
|
test_t5()
|
||||||
|
@ -45,6 +45,7 @@ def check_vit(rank, world_size, port):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.skip
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def test_vit():
|
def test_vit():
|
||||||
|
Loading…
Reference in New Issue
Block a user