mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests
* fixed 2.5d runtime issue
* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -173,7 +173,7 @@ class AccuracyMetric(Metric):
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, targets) -> None:
|
||||
def update(self, logits, targets, batch_size) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
@@ -187,7 +187,7 @@ class AccuracyMetric(Metric):
|
||||
# update
|
||||
correct = self.acc(logits, targets)
|
||||
|
||||
self.last_step_sum.fill_(targets.size(0))
|
||||
self.last_step_sum.fill_(batch_size)
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
@@ -296,7 +296,8 @@ class AccuracyHook(MetricHook):
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, targets)
|
||||
batch_size = trainer.schedule.batch_size
|
||||
self.metric.update(logits, targets, batch_size)
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
@@ -313,10 +314,8 @@ class ThroughputMetric(Metric):
|
||||
self.last_step_num_samples.zero_()
|
||||
self.last_step_used_time.zero_()
|
||||
|
||||
def update(self, tensor, time) -> None:
|
||||
if isinstance(tensor, (list, tuple)):
|
||||
tensor = tensor[0]
|
||||
self.last_step_num_samples.fill_(tensor.size(0))
|
||||
def update(self, num_samples, time) -> None:
|
||||
self.last_step_num_samples.fill_(num_samples)
|
||||
self.last_step_used_time.fill_(time)
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
@@ -354,11 +353,11 @@ class ThroughputHook(MetricHook):
|
||||
def before_train_epoch(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
def after_train_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
Reference in New Issue
Block a user