mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[tensor] fix a assertion in colo_tensor cross_entropy (#1232)
This commit is contained in:
@@ -23,7 +23,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
|
||||
input_tensor = convert_to_colo_tensor(input_tensor, pg)
|
||||
|
||||
if input_tensor.is_replicate(): # Input is gathered
|
||||
assert target.is_replicate() and weight.is_replicate(), \
|
||||
assert target.is_replicate() and (weight is None or weight.is_replicate()), \
|
||||
"Target tensor and weight tensor both should be complete"
|
||||
output = F.cross_entropy(input_tensor,
|
||||
target,
|
||||
|
Reference in New Issue
Block a user