mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refractor and cleanup using common helper funcs. Tests passed * Update comments * Fix relative import * Fix param fetching bug
This commit is contained in:
@@ -13,7 +13,7 @@ _HID_DIM = 128
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=False, dtype=torch.float32):
|
||||
def __init__(self, in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True, dtype=torch.float32):
|
||||
super().__init__()
|
||||
if identity:
|
||||
self.fc0 = nn.Identity()
|
||||
@@ -30,7 +30,7 @@ class Net(nn.Module):
|
||||
class TPNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fc0=nn.Linear(_IN_DIM, _IN_DIM),
|
||||
fc0=nn.Identity(),
|
||||
fc1=nn.Linear(_IN_DIM, _HID_DIM),
|
||||
fc2=nn.Linear(_HID_DIM, _IN_DIM),
|
||||
tp_group=None,
|
||||
|
Reference in New Issue
Block a user