[bug] shard param during initializing the ShardedModelV2 (#381)
parent 8c18eb0998
commit 272ebfb57d
```diff
@@ -139,7 +139,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         if self.convert_fp16:
             param.data = param.data.to(torch.half)
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half).to(target_device)
+                param.grad = param.grad.to(torch.half)
 
         # move torch parameters to the target device
         param.data = param.data.to(target_device)
```
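This hunk changes only the gradient line: the fp16 branch now casts `param.grad` to half precision without also moving it to `target_device`, leaving device placement to the unconditional `param.data` move at the end of the hook. A minimal sketch of the resulting logic, assuming a standalone helper (`convert_param` is a hypothetical name, not from the repo; `convert_fp16` and `target_device` come from the surrounding context):

```python
import torch

def convert_param(param: torch.nn.Parameter,
                  convert_fp16: bool,
                  target_device: torch.device) -> None:
    # Hypothetical standalone version of the post-init conversion above.
    if convert_fp16:
        param.data = param.data.to(torch.half)
        if param.grad is not None:
            # After the fix: cast the gradient to fp16 only; it is no
            # longer eagerly moved to target_device at this point.
            param.grad = param.grad.to(torch.half)

    # The parameter payload itself is still moved to the target device.
    param.data = param.data.to(target_device)
```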
```diff
@@ -1,3 +1,4 @@
+from ast import Try
 import functools
 from collections import OrderedDict
 from typing import Any, Optional
```
```diff
@@ -54,7 +55,7 @@ class ShardedModelV2(nn.Module):
         # In case user didn't use ZeroInitContext
         for param in self.module.parameters():
             if not hasattr(param, 'col_attr'):
-                param.col_attr = ShardedParamV2(param, process_group)
+                param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
                 if self.shard_param:
                     self.shard_strategy.shard([param.col_attr.data])
```
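With this fix, `ShardedModelV2` shards parameters while it is being initialized, as the commit title says: any parameter not already wrapped by `ZeroInitContext` gets a `ShardedParamV2` whose original torch payload is released immediately (`rm_torch_payload=True`), and the shard strategy runs right away when `shard_param` is enabled. A hedged usage sketch; the import paths and constructor arguments are assumptions based on names visible in this diff and may differ between ColossalAI versions:

```python
import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2     # assumed import path

# A plain module built *without* ZeroInitContext: its parameters carry no
# `col_attr`, so ShardedModelV2.__init__ wraps and shards them itself.
model = nn.Linear(1024, 1024)
sharded_model = ShardedModelV2(model,
                               shard_strategy=TensorShardStrategy(),
                               shard_param=True)  # payloads are sharded here, at init
```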