[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -8,9 +8,9 @@ def test_get_batch_size():
assert get_batch_size(tensor) == 2
assert get_batch_size([tensor]) == 2
assert get_batch_size((1, tensor)) == 2
assert get_batch_size({'tensor': tensor}) == 2
assert get_batch_size({'dummy': [1], 'tensor': tensor}) == 2
assert get_batch_size({'tensor': [tensor]}) == 2
assert get_batch_size({"tensor": tensor}) == 2
assert get_batch_size({"dummy": [1], "tensor": tensor}) == 2
assert get_batch_size({"tensor": [tensor]}) == 2
def test_get_micro_batch():
@@ -26,12 +26,12 @@ def test_get_micro_batch():
micro_batch = get_micro_batch([x, y], 1, 1)
assert torch.equal(micro_batch[0], x[1:2])
assert torch.equal(micro_batch[1], y[1:2])
micro_batch = get_micro_batch({'x': x, 'y': y}, 0, 1)
assert torch.equal(micro_batch['x'], x[0:1])
assert torch.equal(micro_batch['y'], y[0:1])
micro_batch = get_micro_batch({'x': x, 'y': y}, 1, 1)
assert torch.equal(micro_batch['x'], x[1:2])
assert torch.equal(micro_batch['y'], y[1:2])
micro_batch = get_micro_batch({"x": x, "y": y}, 0, 1)
assert torch.equal(micro_batch["x"], x[0:1])
assert torch.equal(micro_batch["y"], y[0:1])
micro_batch = get_micro_batch({"x": x, "y": y}, 1, 1)
assert torch.equal(micro_batch["x"], x[1:2])
assert torch.equal(micro_batch["y"], y[1:2])
def test_merge_batch():
@@ -42,6 +42,6 @@ def test_merge_batch():
merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]])
assert torch.equal(merged[0], x)
assert torch.equal(merged[1], y)
merged = merge_batch([{'x': x[0:1], 'y': y[0:1]}, {'x': x[1:2], 'y': y[1:2]}])
assert torch.equal(merged['x'], x)
assert torch.equal(merged['y'], y)
merged = merge_batch([{"x": x[0:1], "y": y[0:1]}, {"x": x[1:2], "y": y[1:2]}])
assert torch.equal(merged["x"], x)
assert torch.equal(merged["y"], y)