Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[test] add mixtral transformer test
@@ -1,6 +1,6 @@
 import copy
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
 
 
 def build_model_from_hybrid_plugin(
-    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+    model_fn: Callable,
+    loss_fn: Callable,
+    test_config: Dict[str, Any],
+    optim_class=Adam,
+    sharded_optim_class=Adam,
+    pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
 ):
     use_lazy_init = False
     if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
     else:
         org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
         sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
 
     criterion = loss_fn
 
-    plugin = HybridParallelPlugin(**test_config)
+    plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)
 
     sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
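
The new `pluggin_cls` parameter lets a test build its models with a different `HybridParallelPlugin` subclass instead of hard-coding `HybridParallelPlugin`, which is what the Mixtral test needs. Below is a minimal, hypothetical usage sketch; the Mixtral model/loss helpers, the config values, the `MoeHybridParallelPlugin` import path, and the return-value handling are assumptions for illustration and are not part of this diff.

```python
# Hypothetical sketch of how a Mixtral/MoE test could reuse the updated helper.
# `build_model_from_hybrid_plugin` comes from the test utility module modified in
# this diff (its import path is not shown here, so it is omitted).
from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed import path


def mixtral_model_fn():
    # Placeholder: build a tiny Mixtral model for testing (details not in this diff).
    raise NotImplementedError


def mixtral_loss_fn(outputs, *args, **kwargs):
    # Placeholder loss: many HF-style tests simply return outputs.loss.
    return outputs.loss


# Keys are forwarded to the plugin constructor via **test_config; values are examples.
test_config = {
    "tp_size": 1,
    "pp_size": 2,
    "num_microbatches": 2,
    "precision": "bf16",
}

# Swap in the MoE plugin class via the new `pluggin_cls` argument (spelling as in the
# diff); the default stays HybridParallelPlugin, so existing call sites are unchanged.
result = build_model_from_hybrid_plugin(
    model_fn=mixtral_model_fn,
    loss_fn=mixtral_loss_fn,
    test_config=test_config,
    pluggin_cls=MoeHybridParallelPlugin,
)
```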