mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[misc] fit torch api upgrade and remove legacy import (#6093)
* [amp] fit torch's new api * [amp] fix api call * [amp] fix api call * [misc] fit torch pytree api upgrade * [misc] remove legacy import * [misc] fit torch amp api * [misc] fit torch amp api
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
|
||||
|
||||
from .bias_dropout_add import bias_dropout_add_fused_train
|
||||
from .bias_gelu import bias_gelu_impl
|
||||
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
"""Compile JIT functions before the main training steps"""
|
||||
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
|
||||
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
|
||||
|
Reference in New Issue
Block a user