mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[shardformer] test all optimizations (#4399)
* [shardformer] test all optimizations
* [shardformer] test all optimizations
* [shardformer] test all optimizations
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -16,8 +15,8 @@ from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.auto_policy import Policy
|
||||
from colossalai.shardformer._utils import getattr_
|
||||
from colossalai.shardformer.policies.auto_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
|
||||
|
||||
@@ -156,10 +155,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
|
||||
else:
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
sharded_output = sharded_model(**data)
|
||||
|
||||
sharded_loss = criterion(sharded_output)
|
||||
sharded_loss.backward()
|
||||
sharded_optimizer.backward(sharded_loss)
|
||||
|
||||
org_model.train()
|
||||
data = {k: v.cuda() for k, v in data.items()}
|
||||
org_output = org_model(**data)
|
||||
org_loss = criterion(org_output)
|
||||
org_loss.backward()
|
||||
@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
|
||||
if stage_manager and stage_manager.is_last_stage():
|
||||
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
|
||||
|
||||
assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
|
||||
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
|
||||
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
|
||||
|
||||
|
||||
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
|
||||
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
|
||||
assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
|
||||
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
|
||||
|
||||
|
||||
@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
|
||||
|
||||
assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
|
||||
assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
|
||||
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
|
||||
|
||||
|
||||
@@ -244,6 +245,7 @@ def check_grad(org_model: Module,
|
||||
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
|
||||
|
||||
assert torch.allclose(
|
||||
org_grad, shard_grad, rtol=rtol, atol=atol
|
||||
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
|
||||
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
|
||||
|
Reference in New Issue
Block a user