mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 07:31:19 +00:00
[autoparallel] add cost graph class (#1481)
* [autoparallel] add cost graph class * polish code
This commit is contained in:
@@ -70,7 +70,9 @@ def test_conv_handler():
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=entire_shape,
|
||||
sharding_sequence=sharding_sequence)
|
||||
strategies_vector_for_input.append(sharding_spec)
|
||||
strategy_name = str(sharding_spec.sharding_sequence)
|
||||
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
|
||||
strategies_vector_for_input.append(sharding_strategy)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate conv strategy
|
||||
|
@@ -69,7 +69,9 @@ def test_dot_handler():
|
||||
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=entire_shape,
|
||||
sharding_sequence=sharding_sequence)
|
||||
strategies_vector_for_input.append(sharding_spec)
|
||||
strategy_name = str(sharding_spec.sharding_sequence)
|
||||
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
|
||||
strategies_vector_for_input.append(sharding_strategy)
|
||||
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||
|
||||
# generate dot strategy
|
||||
|
Reference in New Issue
Block a user