[misc] fit torch api upgradation and remove legecy import (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
This commit is contained in:
Hongxin Liu
2024-10-18 16:48:52 +08:00
committed by GitHub
parent 5ddad486ca
commit 58d8b8a2dd
7 changed files with 20 additions and 12 deletions

View File

@@ -1,7 +1,6 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())