optimized context test time consumption (#446)

This commit is contained in:
Frank Lee
2022-03-17 14:40:52 +08:00
committed by GitHub
parent 496cbb0760
commit b72b8445c6
8 changed files with 169 additions and 357 deletions

View File

@@ -449,6 +449,7 @@ class ParallelContext:
dist.destroy_process_group(group)
# destroy global process group
dist.destroy_process_group()
self._groups.clear()
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.

View File

@@ -13,7 +13,7 @@ def assert_not_equal(a: Tensor, b: Tensor):
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3):
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
assert_close(a, b, rtol, atol)
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):

View File

@@ -46,6 +46,7 @@ def free_port():
while True:
try:
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = random.randint(20000, 65000)
sock.bind(('localhost', port))
sock.close()