[gemini] update ddp strict mode (#2518)

* [zero] add strict ddp mode for chunk init

* [gemini] update gpt example
This commit is contained in:
HELSON
2023-01-28 14:35:25 +08:00
committed by GitHub
parent 0af793836c
commit 707b11d4a0
16 changed files with 133 additions and 54 deletions

View File

@@ -1,3 +1,4 @@
import math
from copy import copy
from functools import lru_cache
from typing import Callable, Optional, Set
@@ -303,6 +304,11 @@ class ColoTensor(torch.Tensor):
else:
return size_list[args[0]]
def numel_global(self):
"""Returns the number of elements in the tensor when it's replicated.
"""
return math.prod(self.size_global())
# Some API for dist spec check
def is_replicate(self):