mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-24 10:41:07 +00:00
16 lines
421 B
Python
16 lines
421 B
Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
|
|
|
|
# Public names exported by this module on a star-import.
__all__ = ['GeminiAdamOptimizer']
|
|
|
|
|
|
class GeminiAdamOptimizer(ZeroOptimizer):
    """A ``ZeroOptimizer`` that constructs its own ``HybridAdam`` instance.

    Args:
        model: Module whose ``parameters()`` are handed to ``HybridAdam``; it
            is also passed to ``ZeroOptimizer.__init__`` as the second
            positional argument.
        **defaults: Keyword arguments forwarded unchanged to both the
            ``HybridAdam`` constructor and ``ZeroOptimizer.__init__``.
    """

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        # The same keyword arguments go to the inner optimizer and to the
        # ZeroOptimizer base initializer.
        trainable_params = model.parameters()
        hybrid_adam = HybridAdam(trainable_params, **defaults)
        super().__init__(hybrid_adam, model, **defaults)
|