mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[test] fix gemini checkpoint and gpt test (#4620)
This commit is contained in:
@@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
|
||||
elif plugin_type == 'zero':
|
||||
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
|
||||
elif plugin_type == 'gemini':
|
||||
plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
|
||||
plugin = GeminiPlugin(precision="fp16", initial_scale=32)
|
||||
else:
|
||||
raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
|
||||
|
||||
|
Reference in New Issue
Block a user