[moe] implement tp

This commit is contained in:
botbw
2024-07-16 06:03:57 +00:00
committed by Hongxin Liu
parent 0b5bbe9ce4
commit dc583aa576
8 changed files with 79 additions and 40 deletions

View File

@@ -111,6 +111,7 @@ class GradientStore(BaseStore):
def reset_all_gradients(self):
    """Discard all tracked gradient state.

    Both bookkeeping attributes are rebound to brand-new empty dicts
    rather than cleared in place, so any outside reference to the old
    dict objects keeps its previous contents.
    """
    self._grads_of_params = {}
    self.grad_to_param_mapping = {}
def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
"""Return the id of a parameter which the gradient slice belongs to