mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-11 08:25:07 +00:00
[hotfix] solver bug caused by dict type comm cost (#1686)
parent 3dd6994427
commit 6878e42248
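Judging from the hunks below, the fix has three parts: ShapeConsistencyManager.shape_consistency now reports its communication cost as a dict, so ReshapeHandler collapses it to the scalar "total" entry before it reaches the solver; Solver gains an assert that the extra-cost vector d[i] lines up with the strategy variables s[i]; and torch.flatten is dropped from ELEMENTWISE_FUNC_OP, presumably because flatten changes tensor shape and cannot be treated as elementwise.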
@@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
     torch.multiply,
     torch.nn.functional.relu,
     torch.nn.functional.dropout,
-    torch.flatten,
     # softmax should not be here
     torch.nn.functional.softmax
 ]
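The removal of torch.flatten looks right on shape grounds: every other entry in the list preserves its input shape, which is what lets an elementwise handler reuse the input sharding spec unchanged. A minimal check of that property (mine, not from the patch):

import torch

x = torch.rand(4, 8)

# Elementwise ops keep the input shape, so one sharding spec fits input and output.
assert torch.nn.functional.relu(x).shape == x.shape
assert torch.nn.functional.dropout(x).shape == x.shape

# torch.flatten collapses dimensions, so it cannot share that treatment.
assert torch.flatten(x).shape == (32,)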
@@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
         shape_consistency_manager = ShapeConsistencyManager()
         _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
                                                                                replicate_input_sharding_spec)
+        communication_cost = communication_cost["total"]

         # generate resharding cost
         resharding_costs = self._generate_resharding_costs([input_sharding_spec])
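This one-line handler change is the heart of the hotfix: shape_consistency evidently started returning its communication cost as a dict rather than a number, and a dict flowing into the solver's cost arithmetic is what broke it. A minimal sketch of the failure mode, with assumed keys (only "total" is confirmed by the hunk above):

# Hypothetical cost dict; only the "total" key appears in the patch.
communication_cost = {"forward": 2.0, "backward": 2.0, "total": 4.0}

# Before the fix, the dict itself reached numeric code paths, where
# expressions like `0.5 * communication_cost` raise a TypeError.
try:
    0.5 * communication_cost
except TypeError as e:
    print(f"solver would crash here: {e}")

# The hotfix reduces it to a scalar at the handler boundary:
communication_cost = communication_cost["total"]
assert isinstance(communication_cost, float)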
@@ -319,6 +319,8 @@ class Solver:
         obj = 0
         for i in range(node_nums):
             assert len(s[i]) == len(c[i])
+            assert len(s[i]) == len(d[i])
+
             obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])

         #############################################
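For context, lpDot comes from PuLP and builds the linear objective over each node's one-hot strategy variables. A runnable sketch of the loop above with made-up sizes and costs (s, c, d, and node_nums follow the hunk; everything else is an assumption):

from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

node_nums = 2
strategy_nums = 3

# s[i][j] == 1 iff node i picks strategy j; c and d are per-strategy costs.
s = [[LpVariable(f"s_{i}_{j}", cat="Binary") for j in range(strategy_nums)]
     for i in range(node_nums)]
c = [[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]]
d = [[0.1, 0.2, 0.3], [0.0, 0.0, 0.0]]

prob = LpProblem("strategy_selection", LpMinimize)
obj = 0
for i in range(node_nums):
    # Guard that every cost row matches the number of strategy variables,
    # so a malformed cost (e.g. the dict bug above) fails loudly here.
    assert len(s[i]) == len(c[i])
    assert len(s[i]) == len(d[i])

    obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
prob += obj

# Each node must choose exactly one strategy.
for i in range(node_nums):
    prob += lpSum(s[i]) == 1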