updated tp layers

This commit is contained in:
kurisusnowdeng
2022-10-26 20:54:39 +08:00
committed by アマデウス
parent cb5a587e9a
commit 0b8161fab8
13 changed files with 645 additions and 293 deletions

View File

@@ -22,7 +22,9 @@ class TensorParallelEnv(object):
depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None):
output_group_3d=None,
input_x_weight_group_3d=None,
output_x_weight_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d
@@ -33,6 +35,8 @@ class TensorParallelEnv(object):
self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d
def save(self):
return dict(mode=self.mode,
@@ -44,7 +48,9 @@ class TensorParallelEnv(object):
depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d)
output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d)
tensor_parallel_env = TensorParallelEnv()