[test] add mixtral transformer test

Author: hxwang
Date: 2024-07-02 09:08:41 +00:00
Committed by: Hongxin Liu
Parent: f9b6fcf81f
Commit: 0b76b57cd6
6 changed files with 281 additions and 30 deletions

@@ -1,6 +1,6 @@
 import copy
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):
 def build_model_from_hybrid_plugin(
-    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
+    model_fn: Callable,
+    loss_fn: Callable,
+    test_config: Dict[str, Any],
+    optim_class=Adam,
+    sharded_optim_class=Adam,
+    pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
 ):
     use_lazy_init = False
     if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
     else:
         org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
         sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
     criterion = loss_fn
-    plugin = HybridParallelPlugin(**test_config)
+    plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)
     sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
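
The point of the new `pluggin_cls` hook is that the Mixtral test can now drive this helper with an MoE-aware plugin instead of the previously hard-coded `HybridParallelPlugin`. Below is a minimal sketch of how such a test might call it; `MoeHybridParallelPlugin`, the import paths, the tiny Mixtral config, and the parallel settings are illustrative assumptions rather than contents of this commit, and the call is assumed to run inside a worker already initialized with `colossalai.launch`.

# Sketch of assumed usage, not part of this diff. Import paths and config values
# are illustrative; only the `pluggin_cls` keyword itself comes from the change above.
from torch.optim import Adam
from transformers import MixtralConfig, MixtralForCausalLM

from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed export location
from tests.test_shardformer.test_model._utils import build_model_from_hybrid_plugin  # assumed path


def model_fn():
    # A deliberately tiny Mixtral so the sharded/unsharded model pair fits in a CI worker.
    config = MixtralConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_local_experts=4,
        num_experts_per_tok=2,
    )
    return MixtralForCausalLM(config)


def loss_fn(outputs):
    # Passed through booster.boost() as the criterion; a scalar over the logits is
    # enough for a forward/backward smoke test.
    return outputs.logits.mean()


test_config = {
    "tp_size": 1,
    "pp_size": 2,
    "ep_size": 2,
    "num_microbatches": 2,
    "precision": "fp16",
    "initial_scale": 1,
}

# `pluggin_cls` (spelling as in the diff) replaces the hard-coded HybridParallelPlugin
# inside build_model_from_hybrid_plugin, so an MoE plugin can be dropped in here.
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
    model_fn,
    loss_fn,
    test_config,
    optim_class=Adam,
    sharded_optim_class=Adam,
    pluggin_cls=MoeHybridParallelPlugin,
)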