Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
アマデウス
2021-12-29 23:32:10 +08:00
committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions

View File

@@ -173,7 +173,7 @@ class AccuracyMetric(Metric):
self.accumulated_sum.zero_()
self.accumulated_correct.zero_()
def update(self, logits, targets) -> None:
def update(self, logits, targets, batch_size) -> None:
"""Updates last step accuracy and accumulated accuracy with current logits
and labels. It expects the output has logits and labels.
@@ -187,7 +187,7 @@ class AccuracyMetric(Metric):
# update
correct = self.acc(logits, targets)
self.last_step_sum.fill_(targets.size(0))
self.last_step_sum.fill_(batch_size)
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
@@ -296,7 +296,8 @@ class AccuracyHook(MetricHook):
def after_test_iter(self, trainer, logits, targets, *args):
if self._is_stage_to_compute:
self.metric.update(logits, targets)
batch_size = trainer.schedule.batch_size
self.metric.update(logits, targets, batch_size)
class ThroughputMetric(Metric):
@@ -313,10 +314,8 @@ class ThroughputMetric(Metric):
self.last_step_num_samples.zero_()
self.last_step_used_time.zero_()
def update(self, tensor, time) -> None:
if isinstance(tensor, (list, tuple)):
tensor = tensor[0]
self.last_step_num_samples.fill_(tensor.size(0))
def update(self, num_samples, time) -> None:
self.last_step_num_samples.fill_(num_samples)
self.last_step_used_time.fill_(time)
self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time
@@ -354,11 +353,11 @@ class ThroughputHook(MetricHook):
def before_train_epoch(self, trainer):
self.metric.reset()
def after_train_iter(self, trainer, logits, targets, *args):
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
def after_train_iter(self, trainer, *args):
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
self.metric.reset()
def after_test_iter(self, trainer, logits, targets, *args):
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
def after_test_iter(self, trainer, *args):
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())