[example] add benchmark (#2276)

* add benchmark

* merge common func

* add total and avg tflops

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2023-01-03 17:20:59 +08:00
committed by GitHub
parent 1405b4381e
commit ac863a01d6
4 changed files with 60 additions and 21 deletions

View File

@@ -240,6 +240,10 @@ class WorkerBase(ABC):
output = [output[i] for i in offsets]
return output
def get_numels(self) -> int:
numel = sum(param.numel() for param in self.module_partition.parameters())
return numel
def get_parameters(self) -> List[torch.Tensor]:
return [p for p in self.module_partition.parameters()]
@@ -1115,6 +1119,15 @@ class PipelineEngineBase(ABC, nn.Module):
for fut in sync_futs:
fut.wait()
def remote_numels(self) -> Dict[int, int]:
numels = {}
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[stage_id]
numel = worker_rref.rpc_sync().get_numels()
numels[stage_id] = numel
return numels
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
parameters = {}
actual_stage_num = self._get_actual_stage_num()