[NFC] polish colossalai/global_variables.py code style (#3259)

Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
jiangmingyan 2023-03-27 18:48:16 +08:00 committed by binmakeswell
parent 1ff7d5bfa5
commit 488f37048c

View File

@ -1,56 +1,56 @@
from typing import Optional from typing import Optional
class TensorParallelEnv(object): class TensorParallelEnv(object):
_instance = None _instance = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if cls._instance is None: if cls._instance is None:
cls._instance = object.__new__(cls, *args, **kwargs) cls._instance = object.__new__(cls, *args, **kwargs)
return cls._instance return cls._instance
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.load(*args, **kwargs) self.load(*args, **kwargs)
def load(self, def load(self,
mode: Optional[str] = None, mode: Optional[str] = None,
vocab_parallel: bool = False, vocab_parallel: bool = False,
parallel_input_1d: bool = False, parallel_input_1d: bool = False,
summa_dim: int = None, summa_dim: int = None,
tesseract_dim: int = None, tesseract_dim: int = None,
tesseract_dep: int = None, tesseract_dep: int = None,
depth_3d: int = None, depth_3d: int = None,
input_group_3d=None, input_group_3d=None,
weight_group_3d=None, weight_group_3d=None,
output_group_3d=None, output_group_3d=None,
input_x_weight_group_3d=None, input_x_weight_group_3d=None,
output_x_weight_group_3d=None): output_x_weight_group_3d=None):
self.mode = mode self.mode = mode
self.vocab_parallel = vocab_parallel self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d self.parallel_input_1d = parallel_input_1d
self.summa_dim = summa_dim self.summa_dim = summa_dim
self.tesseract_dim = tesseract_dim self.tesseract_dim = tesseract_dim
self.tesseract_dep = tesseract_dep self.tesseract_dep = tesseract_dep
self.depth_3d = depth_3d self.depth_3d = depth_3d
self.input_group_3d = input_group_3d self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d self.output_x_weight_group_3d = output_x_weight_group_3d
def save(self): def save(self):
return dict(mode=self.mode, return dict(mode=self.mode,
vocab_parallel=self.vocab_parallel, vocab_parallel=self.vocab_parallel,
parallel_input_1d=self.parallel_input_1d, parallel_input_1d=self.parallel_input_1d,
summa_dim=self.summa_dim, summa_dim=self.summa_dim,
tesseract_dim=self.tesseract_dim, tesseract_dim=self.tesseract_dim,
tesseract_dep=self.tesseract_dep, tesseract_dep=self.tesseract_dep,
depth_3d=self.depth_3d, depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d, input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d, weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d, output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d, input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d) output_x_weight_group_3d=self.output_x_weight_group_3d)
tensor_parallel_env = TensorParallelEnv() tensor_parallel_env = TensorParallelEnv()