[shardformer] Pytree fix (#4533)
* pytree test
* test bert
* test bert
* test bert
* revise
* add register
* add register
@@ -6,12 +6,21 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.cuda import get_current_device
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    retain_grad,
+    to_device,
+    tree_map_hf,
+)
 from .base import PipelineSchedule
 
 
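HuggingFace model outputs are `ModelOutput` instances rather than plain tuples or dicts, so the stock `torch.utils._pytree.tree_map` treats them as opaque leaves instead of recursing into their tensor fields. The commit's "add register" steps and the new `tree_map_hf` helper in `._utils` address this by registering such classes as pytree nodes. Below is a minimal sketch of the idea, assuming PyTorch's private `_register_pytree_node` API and an installed `transformers`; everything except the `tree_map_hf` name is illustrative, and the actual ColossalAI implementation may differ:

    # Illustrative sketch only; not the actual ColossalAI implementation.
    from typing import Any, Callable

    # _register_pytree_node is a private PyTorch API (present in 2.0-era releases).
    from torch.utils._pytree import _register_pytree_node, tree_flatten, tree_unflatten


    def register_model_output(cls: type) -> None:
        """Register a ModelOutput subclass as a pytree node, using its keys as context."""

        def flatten(obj):
            # ModelOutput stores only the non-None fields, so items() is exactly
            # the data we need to round-trip.
            keys, values = zip(*obj.items())
            return list(values), list(keys)

        def unflatten(values, keys):
            return cls(**dict(zip(keys, values)))

        _register_pytree_node(cls, flatten, unflatten)


    def tree_map_hf(fn: Callable[[Any], Any], tree: Any) -> Any:
        """tree_map that also descends into registered ModelOutput nodes."""
        leaves, spec = tree_flatten(tree)
        return tree_unflatten([fn(leaf) for leaf in leaves], spec)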
@@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
-                outputs.append(tree_map(detach, output_obj))
+                outputs.append(tree_map_hf(detach, output_obj))
             return loss
         else:
             return output_obj
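The `output_obj` returned by a HuggingFace model's forward pass is such a `ModelOutput`, so detaching it leaf by leaf needs the HF-aware map; with plain `tree_map`, `detach` would receive the whole output object as a single leaf. A hypothetical usage building on the sketch above (not taken from the ColossalAI test suite):

    # Hypothetical usage of the sketch above.
    import torch
    from transformers.modeling_outputs import BaseModelOutput

    register_model_output(BaseModelOutput)

    out = BaseModelOutput(last_hidden_state=torch.randn(2, 4, requires_grad=True))
    detached = tree_map_hf(lambda t: t.detach(), out)

    # The container type and field names survive the round trip.
    assert isinstance(detached, BaseModelOutput)
    assert not detached.last_hidden_state.requires_grad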
@@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 self.send_backward(input_obj_grad)
 
         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
         return {'loss': accum_loss, 'outputs': outputs}
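The last hunk concerns merging per-micro-batch outputs back into one batch. Some models put the batch axis somewhere other than dim 0, so `merge_batch` now takes the concatenation dim from an optional `batch_size_dim` attribute on the model, which first requires unwrapping the `ModelWrapper` that booster plugins place around the user's module. A rough sketch of what a dim-aware `merge_batch` could look like, assuming every leaf is a batched tensor; the real helper lives beside `tree_map_hf` in `._utils` and may handle more cases:

    # Rough sketch of a dim-aware merge_batch; illustrative, all-tensor leaves assumed.
    from typing import Any, List

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten


    def merge_batch(data: List[Any], batch_size_dim: int = 0) -> Any:
        """Concatenate a list of per-micro-batch pytrees along batch_size_dim."""
        flat = [tree_flatten(d) for d in data]          # [(leaves, spec), ...]
        spec = flat[0][1]                               # all micro-batches share one spec
        grouped = zip(*(leaves for leaves, _ in flat))  # leaf i across all micro-batches
        merged = [torch.cat(group, dim=batch_size_dim) for group in grouped]
        return tree_unflatten(merged, spec)

Note that `getattr(model, 'batch_size_dim', 0)` in the diff defaults to dim 0, so models that never set the attribute keep the previous behaviour.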