Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 20:54:55 +00:00)
[test] modify model supporting part of low_level_zero plugin (including corresponding docs)
This commit is contained in:
parent d1fcc0fa4d
commit db40e086c8
@@ -44,12 +44,6 @@ We've tested compatibility on some famous models, following models may not be su
 - `timm.models.convit_base`
 - dlrm and deepfm models in `torchrec`
-- `diffusers.VQModel`
-- `transformers.AlbertModel`
-- `transformers.AlbertForPreTraining`
-- `transformers.BertModel`
-- `transformers.BertForPreTraining`
-- `transformers.GPT2DoubleHeadsModel`
 
 Compatibility problems will be fixed in the future.
@@ -42,12 +42,6 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate
 - `timm.models.convit_base`
 - dlrm and deepfm models in `torchrec`
-- `diffusers.VQModel`
-- `transformers.AlbertModel`
-- `transformers.AlbertForPreTraining`
-- `transformers.BertModel`
-- `transformers.BertForPreTraining`
-- `transformers.GPT2DoubleHeadsModel`
 
 Compatibility problems will be fixed in the future.
@@ -53,16 +53,6 @@ def output_transform_fn(x):
     return dict(output=x)
 
 
-def output_transform_fn(x):
-    if isinstance(x, KeyedTensor):
-        output = dict()
-        for key in x.keys():
-            output[key] = x[key]
-        return output
-    else:
-        return dict(output=x)
-
-
 def get_ebc():
     # EmbeddingBagCollection
     eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
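The duplicate `output_transform_fn` removed above flattened a `torchrec` `KeyedTensor` into a plain dict keyed by feature name, so outputs can be compared entry by entry; any other output was wrapped under a single `output` key. A minimal, self-contained sketch of that pattern, using a hypothetical `FakeKeyedTensor` stand-in (not part of `torchrec`) so it runs without the library:

```python
# FakeKeyedTensor is a hypothetical stand-in for torchrec's KeyedTensor:
# it only exposes the .keys() and key-indexing interface the transform uses.
class FakeKeyedTensor:
    def __init__(self, values):
        self._values = values  # feature name -> tensor-like value

    def keys(self):
        return list(self._values)

    def __getitem__(self, key):
        return self._values[key]


def output_transform_fn(x):
    # A KeyedTensor is not a plain tensor, so flatten it into a dict
    # that downstream compatibility checks can inspect per feature.
    if isinstance(x, FakeKeyedTensor):
        return {key: x[key] for key in x.keys()}
    # Any other model output is wrapped under a single "output" key.
    return dict(output=x)


kt = FakeKeyedTensor({"f1": [1.0, 2.0], "f2": [3.0]})
print(output_transform_fn(kt))     # {'f1': [1.0, 2.0], 'f2': [3.0]}
print(output_transform_fn([0.5]))  # {'output': [0.5]}
```

In the real test file the `isinstance` check targets `torchrec.sparse.jagged_tensor.KeyedTensor`; the stand-in here only mirrors the two methods the transform actually calls.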