mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[booster] implemented the torch ddd + resnet example (#3232)
* [booster] implemented the torch ddd + resnet example * polish code
This commit is contained in:
25
colossalai/interface/model.py
Normal file
25
colossalai/interface/model.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
|
||||
"""
|
||||
A wrapper class to define the common interface used by booster.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The model to be wrapped.
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def unwrap(self):
|
||||
"""
|
||||
Unwrap the model to return the original model for checkpoint saving/loading.
|
||||
"""
|
||||
if isinstance(self.module, ModelWrapper):
|
||||
return self.module.unwrap()
|
||||
return self.module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
Reference in New Issue
Block a user