Develop/experiments (#59)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layers to fix slow computation; tested ImageNet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified APIs of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Frank Lee authored on 2021-12-09 15:08:29 +08:00; committed by GitHub
commit da01c234e1, parent eb2f8b1f6b
229 changed files with 6532 additions and 8741 deletions

View File

@@ -1,5 +1,5 @@
 from ._base_schedule import BaseSchedule
-from ._no_pipeline import NoPipelineSchedule
-from ._pipeline import PipelineSchedule
+from ._pipeline_schedule import PipelineSchedule
+from ._non_pipeline_schedule import NonPipelineSchedule

-__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
+__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule']

View File

@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
 import torch

 from colossalai.core import global_context as gpc
-from colossalai.logging import get_global_dist_logger
 from torch import Tensor
+from typing import Iterable, Union, List, Callable
+from .._base_engine import Engine
+from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
@@ -18,8 +20,9 @@ class BaseSchedule(ABC):
     control of FP16 in class schedule.
     """

-    def __init__(self):
-        self.logger = get_global_dist_logger()
+    def __init__(self, batch_data_process_func: Callable = None):
+        self.logger = get_dist_logger()
+        self.batch_data_process_func = batch_data_process_func

     @staticmethod
     def _move_tensor(element):
@@ -35,6 +38,11 @@ class BaseSchedule(ABC):
         data = data.to(get_current_device()).detach()
         return data

+    def _to_list(self, data):
+        if torch.is_tensor(data):
+            return [data]
+        return data
+
     def load_batch(self, data_iter):
         """Loads a batch from data iterator. It returns the data and labels, which are
         already on the same GPU as the model.
@@ -44,46 +52,34 @@ class BaseSchedule(ABC):
"""
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
data, label = next(data_iter)
batch_data = next(data_iter)
if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
else:
data, label = batch_data
data, label = self._to_list(data), self._to_list(label)
return self._move_to_device(data), self._move_to_device(label)
def initialize(self, model, optimizer):
"""Initializes the model and the optimizer before training.
This is often used in FP16 training.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
"""
return model, optimizer
pass
@abstractmethod
def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
engine: Engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True
):
"""The process function over a batch of dataset for training or evaluation.
:param data_iter: Data iterator of the dataset
:param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param engine: Colossalai training engine
:param inputs: input data
:param labels: ground truth
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
"""
pass
@abstractmethod
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""Updates the parameters with the optimizer.
:param model: The neural network model
:param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
"""
pass
pass

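The batch_data_process_func hook added to BaseSchedule above lets a schedule consume dataloaders whose batches are not plain (data, label) pairs. A minimal sketch of how a caller might supply one; the dict keys are hypothetical, and the import path assumes the schedule package shown in the first diff:

    from colossalai.engine.schedule import NonPipelineSchedule

    # Hypothetical dataloader that yields dict batches instead of (data, label) pairs.
    def dict_batch_to_pair(batch_data):
        return batch_data['image'], batch_data['target']

    # load_batch() will now unpack each dict batch through the hook before
    # wrapping tensors in lists and moving them to the current device.
    schedule = NonPipelineSchedule(batch_data_process_func=dict_batch_to_pair)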
View File

@@ -1,188 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-try:
-    import apex.amp as apex_amp
-except:
-    pass
-
-try:
-    import torch.cuda.amp as torch_amp
-except:
-    pass
-
-from typing import Iterable
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
-                           ZeroRedundancyOptimizer_Level_3)
-from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._base_schedule import BaseSchedule
-from ._utils import convert_to_fp16, convert_to_fp32
-from ..amp import AMP_TYPE, GradScaler
-
-
-class NoPipelineSchedule(BaseSchedule):
-    """A helper schedule class for no pipeline parallelism running environment.
-    During one process, it loads a batch of dataset and feeds it to the model.
-    After getting the output and calculating the loss, it will use :meth:`step`
-    to update the parameters if it is in training mode.
-
-    :param amp_type: The type of automatic mixed precision
-    :param amp_config: The configuration of automatic mixed precision
-    :type amp_type: AMP_TYPE
-    :type amp_config: dict
-    """
-
-    def __init__(
-            self,
-            amp_type: AMP_TYPE = None,
-            amp_config: dict = None,
-    ):
-        super().__init__()
-
-        # mixed precision training
-        assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
-            'unrecognised value for argument fp16, it can only be None, torch or apex'
-
-        self.use_zero_level_2_3 = False
-
-        if amp_type is not None:
-            self.fp16 = True
-            self.amp_type = amp_type
-
-            if amp_config is not None:
-                assert isinstance(amp_config, dict), \
-                    f'expected argument fp16_config to be type dictionary, but got {type(amp_config)}'
-
-            if self.amp_type == AMP_TYPE.TORCH:
-                # torch amp
-                if amp_config is None:
-                    amp_config = dict()
-                self.amp_cfg = amp_config
-            elif self.amp_type == AMP_TYPE.APEX:
-                # apex amp
-                if amp_config is None:
-                    amp_config = dict(opt_level='O2')
-                self.logger.warning(
-                    'apex is deprecated, please consider using torch.cuda.amp instead.'
-                )
-                self.amp_cfg = amp_config
-            elif self.amp_type == AMP_TYPE.PARALLEL:
-                # use fp16 optimizer for tensor parallel training
-                if amp_config is None:
-                    amp_config = dict()
-                self.amp_cfg = amp_config
-        else:
-            self.fp16 = False
-            self.amp_type = None
-
-    def initialize(self, model: nn.Module, optimizer: Optimizer):
-        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
-                                  ZeroRedundancyOptimizer_Level_3)):
-            self.use_zero_level_2_3 = True
-            assert self.amp_type != AMP_TYPE.PARALLEL, \
-                'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
-
-        if self.fp16:
-            if self.amp_type == AMP_TYPE.TORCH:
-                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
-            elif self.amp_type == AMP_TYPE.APEX:
-                model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
-
-        return model, optimizer
-
-    def forward_backward_step(self,
-                              data_iter: Iterable,
-                              model: nn.Module,
-                              criterion: nn.modules.loss._Loss,
-                              optimizer: Optimizer = None,
-                              forward_only: bool = False,
-                              grad_accum_size: int = 1,
-                              return_loss: bool = True):
-        """The process function that loads a batch of dataset and feeds it to the model.
-        The returned labels and loss will be None if :attr:`return_loss` is False.
-
-        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
-        :param model: Model for training and inference
-        :param criterion: Loss function for training
-        :param optimizer: Optimizer used for training
-        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
-        :param grad_accum_size: The number of iterations for gradient accumulation
-        :param return_loss: Loss will be returned if True
-        :type data_iter: Iterator
-        :type model: torch.nn.Module
-        :type criterion: torch.nn.modules.loss._Loss
-        :type optimizer: torch.optim.Optimizer
-        :type forward_only: bool, optional
-        :type grad_accum_size: int
-        :type return_loss: bool, optional
-        :return: (output, label, loss)
-        """
-        assert forward_only or return_loss, \
-            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
-
-        data, label = self.load_batch(data_iter)
-        loss = None
-
-        # forward
-        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
-            with torch_amp.autocast():
-                output = model(*data)
-                if not isinstance(output, (tuple, list)):
-                    output = (output,)
-                if return_loss:
-                    loss = criterion(*output, *label)
-        else:
-            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
-                data = convert_to_fp16(data)
-
-            output = model(*data)
-
-            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
-                output = convert_to_fp32(output)
-
-            if not isinstance(output, (tuple, list)):
-                output = (output,)
-            if return_loss:
-                loss = criterion(*output, *label)
-
-        loss /= grad_accum_size
-
-        if not forward_only:
-            # backward
-            if self.use_zero_level_2_3:
-                optimizer.backward(loss)
-            elif self.fp16:
-                if self.amp_type == AMP_TYPE.APEX:
-                    with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                elif self.amp_type == AMP_TYPE.TORCH:
-                    self._torch_amp_scaler.scale(loss).backward()
-                elif self.amp_type == AMP_TYPE.PARALLEL:
-                    loss = optimizer.scale_loss(loss)
-                    loss.backward()
-                    # scale back to display the original value in logs
-                    loss.div_(optimizer.grad_scaler.scale)
-            else:
-                loss.backward()
-
-        if return_loss:
-            return output, label, loss * grad_accum_size
-        else:
-            return output, None, None
-
-    def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
-        # step optimizer
-        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
-            if grad_clipping > 0.0:
-                self._torch_amp_scaler.unscale_(optimizer)
-                clip_grad_norm_fp32(model.parameters(), grad_clipping)
-            self._torch_amp_scaler.step(optimizer)
-            self._torch_amp_scaler.update()
-        else:
-            if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
-                clip_grad_norm_fp32(model.parameters(), grad_clipping)
-            optimizer.step()

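The deleted AMP_TYPE.TORCH branch follows the standard torch.cuda.amp recipe: run the forward under autocast, scale the loss before backward, unscale before clipping, then step and update through the scaler. For reference, the same pattern in plain PyTorch, assuming a CUDA device is available:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    criterion = torch.nn.CrossEntropyLoss()
    scaler = GradScaler()

    data = torch.randn(8, 16, device='cuda')
    label = torch.randint(0, 4, (8,), device='cuda')

    optimizer.zero_grad()
    with autocast():                  # forward in mixed precision
        loss = criterion(model(data), label)
    scaler.scale(loss).backward()     # scale loss to avoid fp16 gradient underflow
    scaler.unscale_(optimizer)        # unscale before clipping, as the deleted code did
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)            # skips the step if gradients contain inf/nan
    scaler.update()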
View File

@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import Iterable
+
+import torch
+import torch.nn as nn
+from colossalai.engine import Engine
+from torch.optim import Optimizer
+
+from ._base_schedule import BaseSchedule
+from colossalai.utils import conditional_context
+
+
+class NonPipelineSchedule(BaseSchedule):
+    """A helper schedule class for no pipeline parallelism running environment.
+    During one process, it loads a batch of dataset and feeds it to the model.
+    After getting the output and calculating the loss, it will use :meth:`step`
+    to update the parameters if it is in training mode.
+
+    :param amp_type: The type of automatic mixed precision
+    :param amp_config: The configuration of automatic mixed precision
+    :type amp_type: AMP_TYPE
+    :type amp_config: dict
+    """
+
+    def forward_backward_step(self,
+                              engine: Engine,
+                              data_iter: Iterable,
+                              forward_only: bool = False,
+                              return_loss: bool = True):
+        """The process function that loads a batch of dataset and feeds it to the model.
+        The returned labels and loss will be None if :attr:`return_loss` is False.
+
+        :param engine: Colossalai training engine
+        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
+        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
+        :param return_loss: Loss will be returned if True
+        :type engine: Engine
+        :type data_iter: Iterator
+        :type forward_only: bool, optional
+        :type return_loss: bool, optional
+        :return: (output, label, loss)
+        """
+        assert forward_only or return_loss, \
+            "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
+        data, label = self.load_batch(data_iter)
+
+        # forward
+        with conditional_context(torch.no_grad(), enable=forward_only):
+            output = engine(*data)
+            if not isinstance(output, (tuple, list)):
+                output = (output,)
+            if return_loss:
+                loss = engine.criterion(*output, *label)
+
+        if not forward_only:
+            engine.backward(loss)
+
+        if return_loss:
+            return output, label, loss
+        else:
+            return output, None, None

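The new schedule leans on colossalai.utils.conditional_context to run one forward path for both training and evaluation: with enable=forward_only, gradients are tracked during training and skipped during evaluation without duplicating code. A minimal sketch of the behavior such a helper presumably implements (not the library's actual source):

    from contextlib import contextmanager

    @contextmanager
    def conditional_context(context_manager, enable: bool = True):
        # Enter the given context manager only when enable is True;
        # otherwise run the body as a no-op context.
        if enable:
            with context_manager:
                yield
        else:
            yield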
View File

@@ -10,12 +10,12 @@ from torch import Tensor
 from colossalai.communication import *
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
-                           ZeroRedundancyOptimizer_Level_3)
+from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
+                             ZeroRedundancyOptimizer_Level_3)
 from colossalai.utils import get_current_device
 from ._base_schedule import BaseSchedule
-from ._utils import convert_to_fp16
-from ..amp import AMP_TYPE
+from colossalai.amp import AMP_TYPE


 def squeeze(x: Union[Tensor, tuple, list]):
@@ -28,32 +28,25 @@ def squeeze(x: Union[Tensor, tuple, list]):
 class PipelineSchedule(BaseSchedule):
     """A helper schedule class for pipeline parallelism running environment.
     It uses non-interleaved 1F1B strategy. Other properties are similar to
-    :class:`NoPipelineSchedule`.
+    :class:`NonPipelineSchedule`.

     :param num_microbatches: The number of microbatches
-    :param amp_type: The type of automatic mixed precision
-    :param amp_config: The configuration of automatic mixed precision
+    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
     :type num_microbatches: int
-    :type amp_type: AMP_TYPE
-    :type amp_config: dict
+    :type sync_data: bool
     """

     def __init__(self,
                  num_microbatches,
-                 amp_type: AMP_TYPE = None,
-                 amp_config: dict = None):
+                 sync_data: bool = True):
         super().__init__()

         self.num_microbatches = num_microbatches
-        self.data_sync = True  # close after making sure data is identical
-
-        # amp
-        # LSGL: amp_config is not used, but leave here for future extension
-        self.amp_type = amp_type
-        self.amp_config = amp_config
-
-        if self.amp_type is not None:
-            assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
+        self.sync_data = sync_data

     def _move_to_device(self, data):
         if isinstance(data, (
@@ -67,30 +60,37 @@ class PipelineSchedule(BaseSchedule):
         return data

     def _sync_data(self):
+        reqs = []
         if gpc.is_first_rank(ParallelMode.PIPELINE):
             src_rank = gpc.get_global_rank()
-            dist.broadcast(
+            reqs.append(dist.broadcast(
                 tensor=self.batch_data,
                 src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_PREV)
-            )
-            dist.broadcast(
+                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
+                async_op=True
+            ))
+            reqs.append(dist.broadcast(
                 tensor=self.batch_label,
                 src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_PREV)
-            )
+                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
+                async_op=True
+            ))
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-            dist.broadcast(
+            reqs.append(dist.broadcast(
                 tensor=self.batch_data,
                 src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
-            )
-            dist.broadcast(
+                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
+                async_op=True
+            ))
+            reqs.append(dist.broadcast(
                 tensor=self.batch_label,
                 src=src_rank,
-                group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
-            )
+                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
+                async_op=True
+            ))
+        for req in reqs:
+            req.wait()

     # Pipeline schedule just puts data in memory
     def load_batch(self, data_iter):
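The rewritten _sync_data above issues every broadcast with async_op=True and waits on the returned handles together, so the data and label transfers overlap instead of running back to back. The underlying torch.distributed pattern in isolation, assuming an initialized process group:

    import torch.distributed as dist

    def broadcast_batch_async(batch_data, batch_label, src_rank, group=None):
        # Launch both broadcasts without blocking, then wait on the handles,
        # letting the two transfers proceed concurrently.
        reqs = [
            dist.broadcast(batch_data, src=src_rank, group=group, async_op=True),
            dist.broadcast(batch_label, src=src_rank, group=group, async_op=True),
        ]
        for req in reqs:
            req.wait()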
@@ -104,7 +104,7 @@ class PipelineSchedule(BaseSchedule):
         assert batch_size % self.num_microbatches == 0, \
             "Batch size should be divisible by the number of microbatches"
         self.microbatch_size = batch_size // self.num_microbatches
-        if self.data_sync:
+        if self.sync_data:
             self._sync_data()

     def _get_data_slice(self, tensor):
@@ -116,21 +116,20 @@ class PipelineSchedule(BaseSchedule):
         self.batch_pos += self.microbatch_size
         return (data,), (label,)

-    def initialize(self, model, optimizer):
-        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+    def pre_processing(self, engine):
+        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
             raise TypeError(
                 "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
             )

         # LSG: set default dtype to fp16 for communication
-        if self.amp_type == AMP_TYPE.PARALLEL:
+        if isinstance(engine.model, NaiveAMPModel):
             torch.set_default_dtype(torch.half)
-            self.logger.info(
+            self.logger.warning(
                 'default tensor dtype is set to torch.half for fp16 training',
                 ranks=[0])

-    def forward_step(self, model, criterion, input_tensor, return_tensors,
-                     grad_accum_size, return_loss=True):
+    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -138,17 +137,16 @@ class PipelineSchedule(BaseSchedule):
         if input_tensor is None:
             input_tensor, label = self.load_micro_batch()
-            if self.amp_type == AMP_TYPE.PARALLEL:
-                input_tensor = convert_to_fp16(input_tensor)
-
         input_tensor = squeeze(input_tensor)
-        output_tensor = model(input_tensor)
+        output_tensor = engine(input_tensor)
         output_tensor = squeeze(output_tensor)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_loss:
                 input_tensor, label = self.load_micro_batch()
-                loss_reduced = criterion(output_tensor, *label) \
-                    / (self.num_microbatches * grad_accum_size)
+                loss_reduced = engine.criterion(output_tensor, *label) \
+                    / self.num_microbatches
                 return_tensors.append(
                     tuple((output_tensor, label[0], loss_reduced)))
                 return loss_reduced
@@ -159,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
         else:
             return output_tensor

-    def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
+    def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_tensor_grad is None, otherwise it is the gradients with respect to the stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
@@ -171,9 +169,10 @@ class PipelineSchedule(BaseSchedule):
             input_tensor.retain_grad()

         # Backward pass.
-        if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
-            output_tensor = optimizer.scale_loss(output_tensor)
-        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+        if output_tensor_grad is None:
+            engine.backward(output_tensor)
+        else:
+            engine.backward_by_grad(output_tensor, output_tensor_grad)

         # Collect the grad of the input_tensor.
         input_tensor_grad = None
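backward_step now routes autograd through the engine: the last stage starts backpropagation from the loss itself, while earlier stages seed it with the output gradient received from the next stage. In plain PyTorch terms, the two branches correspond to the following sketch (engine.backward_by_grad is Colossal-AI's name; torch.autograd.backward is the underlying mechanism):

    import torch

    def backward_from(output_tensor, output_grad=None):
        if output_grad is None:
            # Last pipeline stage: output_tensor is the (scalar) loss.
            torch.autograd.backward(output_tensor)
        else:
            # Intermediate stage: seed autograd with the gradient of this
            # stage's output, received from the next stage.
            torch.autograd.backward(output_tensor, grad_tensors=output_grad)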
@@ -183,12 +182,9 @@ class PipelineSchedule(BaseSchedule):
         return input_tensor_grad

     def forward_backward_step(self,
+                              engine,
                               data_iter,
-                              model,
-                              criterion,
-                              optimizer=None,
                               forward_only=False,
-                              grad_accum_size: int = 1,
                               return_loss=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -226,9 +222,8 @@ class PipelineSchedule(BaseSchedule):
             ft_shape = recv_tensor_meta(ft_shape)
         input_tensor = recv_forward(ft_shape)
         output_tensor = self.forward_step(
-            model, criterion,
-            input_tensor, return_tensors,
-            grad_accum_size, return_loss=return_loss
+            engine, input_tensor, return_tensors,
+            return_loss=return_loss
         )
         if not gpc.is_last_rank(ParallelMode.PIPELINE):
             bt_shape = output_tensor.shape
@@ -252,9 +247,8 @@ class PipelineSchedule(BaseSchedule):
             last_iteration = (i == (num_microbatches_remaining - 1))

             output_tensor = self.forward_step(
-                model, criterion,
-                input_tensor, return_tensors,
-                grad_accum_size, return_loss=return_loss
+                engine, input_tensor, return_tensors,
+                return_loss=return_loss
             )
             if forward_only:
                 send_forward(output_tensor)
@@ -276,7 +270,7 @@ class PipelineSchedule(BaseSchedule):
                 output_tensor = output_tensors.pop(0)

                 input_tensor_grad = self.backward_step(
-                    optimizer,
+                    engine,
                     input_tensor, output_tensor,
                     output_tensor_grad
                 )
@@ -297,7 +291,7 @@ class PipelineSchedule(BaseSchedule):
                 output_tensor_grad = recv_backward(bt_shape)

                 input_tensor_grad = self.backward_step(
-                    optimizer,
+                    engine,
                     input_tensor, output_tensor,
                     output_tensor_grad
                 )
@@ -309,11 +303,8 @@ class PipelineSchedule(BaseSchedule):
                 output, label, loss = tuple(map(list, zip(*return_tensors)))
                 return (torch.cat(output, dim=0),
                         torch.cat(label, dim=0),
-                        sum(loss) * grad_accum_size)
+                        sum(loss))
             else:
                 return tuple((torch.cat(return_tensors, dim=0), None, None))
         else:
             return tuple((None, None, None))

-    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
-        optimizer.step()

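Throughout this file, the schedule carves each global batch into num_microbatches equal slices and runs them through the non-interleaved 1F1B order. The slicing bookkeeping, mirrored as a standalone sketch (simplified to a single tensor; the real _get_data_slice walks a batch position instead):

    import torch

    def split_into_microbatches(batch: torch.Tensor, num_microbatches: int):
        # Mirrors the assert in PipelineSchedule.load_batch: the batch size
        # must be divisible by the number of microbatches.
        batch_size = batch.size(0)
        assert batch_size % num_microbatches == 0, \
            "Batch size should be divisible by the number of microbatches"
        microbatch_size = batch_size // num_microbatches
        return [batch[i * microbatch_size:(i + 1) * microbatch_size]
                for i in range(num_microbatches)]

    # e.g. a batch of 8 split into 4 microbatches of 2 for 1F1B scheduling
    microbatches = split_into_microbatches(torch.randn(8, 3), 4)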
View File

@@ -1,27 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Union, List
-
-from torch import Tensor
-
-
-def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
-    if isinstance(data, Tensor):
-        ret = data.half()
-    elif isinstance(data, (list, tuple)):
-        ret = [val.half() for val in data]
-    else:
-        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
-    return ret
-
-
-def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
-    if isinstance(data, Tensor):
-        ret = data.float()
-    elif isinstance(data, (list, tuple)):
-        ret = [val.float() for val in data]
-    else:
-        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
-    return ret
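These deleted helpers only cast a tensor, or a list/tuple of tensors, between fp16 and fp32; after this commit that responsibility presumably moves into the AMP model wrappers (NaiveAMPModel is now imported by the pipeline schedule above). A rough, hypothetical sketch of such a wrapper, for illustration rather than the actual implementation:

    import torch
    import torch.nn as nn

    class HalfPrecisionWrapper(nn.Module):
        # Hypothetical stand-in for an AMP wrapper that absorbs the
        # convert_to_fp16/convert_to_fp32 responsibility at the model boundary.
        def __init__(self, model: nn.Module):
            super().__init__()
            self.model = model.half()

        def forward(self, *args):
            args = [a.half() if torch.is_floating_point(a) else a for a in args]
            out = self.model(*args)
            return out.float() if torch.is_tensor(out) else out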