[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)

* Refractor and cleanup using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
This commit is contained in:
Wenxuan Tan
2025-02-11 23:42:34 -06:00
committed by GitHub
parent 5c09d726a6
commit ec73f1b5e2
8 changed files with 142 additions and 298 deletions

View File

@@ -13,7 +13,7 @@ _HID_DIM = 128
class Net(nn.Module):
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32):
super().__init__()
if identity:
self.fc0 = nn.Identity()
@@ -30,7 +30,7 @@ class Net(nn.Module):
class TPNet(nn.Module):
def __init__(
self,
fc0=nn.Linear(_IN_DIM, _IN_DIM),
fc0=nn.Identity(),
fc1=nn.Linear(_IN_DIM, _HID_DIM),
fc2=nn.Linear(_HID_DIM, _IN_DIM),
tp_group=None,