mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-02 16:28:58 +00:00
add algo selection
This commit is contained in:
parent
812f4b7750
commit
0f566cc2d4
@ -2,9 +2,15 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer
|
||||
from .producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {
|
||||
"Simple": SimpleConsumer,
|
||||
"GRPO": GRPOConsumer,
|
||||
}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
with open(path) as f:
|
||||
@ -40,7 +46,14 @@ def launch_distributed(
|
||||
inference_backend: str = "transformers",
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||
else:
|
||||
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
|
||||
|
||||
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
|
||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||
|
||||
@ -68,7 +81,7 @@ def launch_distributed(
|
||||
)
|
||||
procs.append(producer)
|
||||
for i in range(num_consumer_procs):
|
||||
consumer = GRPOConsumer.options(num_gpus=1).remote(
|
||||
consumer = core_consumer.options(num_gpus=1).remote(
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
|
@ -15,6 +15,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
||||
parser.add_argument("-a", "--algo", type=str, default="GPRO", choices=["Simple, GPRO"])
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
@ -95,4 +96,5 @@ if __name__ == "__main__":
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29504,
|
||||
core_algo=args.algo
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user