[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
Author: Hongxin Liu
Date: 2024-01-09 10:20:05 +08:00
Committed by: GitHub
Parent: dd2c28a323
Commit: d202cc28c0
128 changed files with 1773 additions and 868 deletions
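
The hunks below all apply the same substitution: device lookups previously imported from colossalai.utils are routed through the accelerator API instead. A minimal sketch of the pattern, assuming a ColossalAI build that ships colossalai.accelerator.get_accelerator (the tensor at the end is only for illustration):

# Sketch of the migration this commit applies throughout the examples.
# Old: from colossalai.utils import get_current_device; device = get_current_device()
# New: route the device query through the accelerator abstraction.
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()            # backend-specific accelerator object (CUDA, NPU, CPU, ...)
device = accelerator.get_current_device()  # replaces colossalai.utils.get_current_device()
x = torch.randn(4, 4, device=device)       # illustrative tensor placed on the active accelerator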


@@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var
 from utils.logger import Logger
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.context import ParallelMode
 from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -53,7 +53,7 @@ def main():
     set_global_variables(launch_time, args.tensorboard_path)
     world_size = torch.distributed.get_world_size()
-    get_current_device()
+    get_accelerator().get_current_device()
     # build model, optimizer and criterion
     if args.distplan.startswith("CAI"):
@@ -67,7 +67,10 @@ def main():
         # build GPT model
         with ColoInitContext(
-            device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
+            device=get_accelerator().get_current_device(),
+            dtype=torch.half,
+            default_dist_spec=default_dist_spec,
+            default_pg=shard_pg,
         ):
             config, model, numel = get_model(args, logger)
@@ -78,7 +81,7 @@ def main():
         elif args.distplan == "CAI_Gemini":
             gemini_config = dict(
                 strict_ddp_mode=args.tp_degree == 1,
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 placement_policy=args.placement,
                 pin_memory=True,
                 hidden_dim=model.config.hidden_size,
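
The "use nullcontext" and "fix autocast" entries in the changelog are not covered by the hunks shown above. A generic sketch of that pattern, not the commit's exact code: fall back to a no-op context manager when mixed precision is disabled. torch.autocast and contextlib.nullcontext are standard library calls; the use_fp16 flag and helper name are hypothetical.

# Fall back to a no-op context when autocast is not wanted.
from contextlib import nullcontext
import torch

def autocast_or_null(device_type: str, use_fp16: bool):
    # use_fp16 is a hypothetical flag for this sketch only.
    return torch.autocast(device_type=device_type, dtype=torch.float16) if use_fp16 else nullcontext()

with autocast_or_null("cuda", use_fp16=False):
    pass  # model forward/backward would go here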