mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -33,38 +33,34 @@ def test_repeat_interleave():
|
||||
data = torch.tensor([1, 2, 3])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=2)
|
||||
repeat_interleave = partial(patch_fn, repeats=2)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
meta_data = data.to("meta")
|
||||
_assert_output_shape(
|
||||
data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape
|
||||
)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=3, dim=1)
|
||||
repeat_interleave = partial(patch_fn, repeats=3, dim=1)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
meta_data = data.to("meta")
|
||||
_assert_output_shape(
|
||||
data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape
|
||||
)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1)
|
||||
repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
meta_data = data.to("meta")
|
||||
_assert_output_shape(
|
||||
data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape
|
||||
)
|
||||
|
||||
data = torch.tensor([[1, 2], [3, 4]])
|
||||
materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0)
|
||||
repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0)
|
||||
meta_data = data.to('meta')
|
||||
_assert_output_shape(data=meta_data,
|
||||
patch_fn=repeat_interleave,
|
||||
expect_exception=True,
|
||||
output_shape=materialized_output.shape)
|
||||
meta_data = data.to("meta")
|
||||
_assert_output_shape(
|
||||
data=meta_data, patch_fn=repeat_interleave, expect_exception=True, output_shape=materialized_output.shape
|
||||
)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
|
Reference in New Issue
Block a user