mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 18:39:56 +00:00
[hotfix] Suport extra_kwargs in ShardConfig (#5031)
* [refactor]: replace inference args with extra_kwargs in ShardConfig * modify shardconfig * polish code * fix policy bug in llama * fix bug in auto policy * remove setattr in ShardConfig
This commit is contained in:
@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
||||
:class:`Policy`: The auto policy for the model
|
||||
"""
|
||||
full_name = _fullname(model)
|
||||
if shard_config.inference_only:
|
||||
inference_only = shard_config.extra_kwargs.get("inference_only", False)
|
||||
if inference_only:
|
||||
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
||||
else:
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location, shard_config.inference_only)
|
||||
policy = import_policy(policy_location, inference_only)
|
||||
return policy()
|
||||
|
Reference in New Issue
Block a user