mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[shardformer] adapted llama to the new API (#4036)
This commit is contained in:
@@ -47,10 +47,12 @@ class ShardFormer:
|
||||
"""
|
||||
Initialize the distributed process group according to the
|
||||
"""
|
||||
# create process group manager and 1d process group
|
||||
# TODO: may need to support other parallel mode when the config has such as field
|
||||
pg_manager = ProcessGroupManager()
|
||||
if (self.shard_config.tensor_parallel_mode == '1d'):
|
||||
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
|
||||
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
|
||||
self.pg_manager = pg_manager
|
||||
|
||||
return pg_manager
|
||||
|
||||
def shard_model(self, model: nn.Module, policy: Policy = None):
|
||||
|
Reference in New Issue
Block a user