mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[moe] implement tp
This commit is contained in:
@@ -111,6 +111,7 @@ class GradientStore(BaseStore):
|
||||
|
||||
def reset_all_gradients(self):
|
||||
self._grads_of_params = dict()
|
||||
self.grad_to_param_mapping = dict()
|
||||
|
||||
def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
|
||||
"""Return the id of a parameter which the gradient slice belongs to
|
||||
|
Reference in New Issue
Block a user