Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-21 01:24:04 +00:00
[feat] add microbatch forwarding (#6251)
* add microbatch forwarding
* fix forward microbatch
* fix producer OOM
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* change project name
* fix temperature annealing
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
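The idea behind microbatch forwarding is to split a large batch into smaller chunks for the forward pass so that peak activation memory stays bounded, which is also what the "fix producer OOM" item above addresses on the rollout side. A minimal sketch of the pattern under assumed names (forward_microbatched and the model call signature are illustrative, not the ColossalAI implementation):

    import torch

    def forward_microbatched(model, input_ids, attention_mask, microbatch_size):
        # Forward a large batch in chunks of microbatch_size samples so that
        # only one chunk's activations are resident at a time (producer-side
        # rollout forwarding, so gradients are not retained).
        chunks = []
        for start in range(0, input_ids.size(0), microbatch_size):
            end = start + microbatch_size
            with torch.no_grad():
                out = model(
                    input_ids=input_ids[start:end],
                    attention_mask=attention_mask[start:end],
                )
            chunks.append(out.logits)
        return torch.cat(chunks, dim=0)

On the training side the same chunking is combined with gradient accumulation so the optimizer still sees the full minibatch (see the sketch after the second hunk below).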
@@ -34,6 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_microbatch_size: int,
+    train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
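After this hunk, launch_distributed distinguishes three related knobs: train_batch_size (rollouts collected per step), train_minibatch_size (the slice handed to the consumer per update), and train_microbatch_size (the chunk actually forwarded at once). The divisibility relationship below is an assumption inferred from the parameter names rather than something stated in the diff; a minimal sanity-check sketch:

    def check_batch_sizes(train_batch_size, train_minibatch_size, train_microbatch_size):
        # Assumed hierarchy: batch >= minibatch >= microbatch, with even splits.
        assert train_batch_size % train_minibatch_size == 0, \
            "global batch should split evenly into minibatches"
        assert train_minibatch_size % train_microbatch_size == 0, \
            "minibatch should split evenly into forward microbatches"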
@@ -99,9 +100,13 @@ def launch_distributed(
         batch_size=train_batch_size,
         model_config=train_model_config,
         plugin_config=plugin_config,
-        microbatch_size=train_microbatch_size,
+        microbatch_size=train_minibatch_size,
         generate_config=generate_config_consumer,
-        training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+        training_config={
+            "filter_range": [0.05, 9.0],
+            "lr": 1e-6,
+            "train_microbatch_size": train_microbatch_size,
+        },
         num_generations=num_generations,
     )
     procs.append(consumer)
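The consumer call now takes its minibatch size from train_minibatch_size, while the forward chunk size travels inside training_config as "train_microbatch_size". How the consumer uses that value is not shown in this diff; the sketch below assumes a plain gradient-accumulation loop with hypothetical names (train_step, and a model call that returns an object with a .loss attribute):

    def train_step(model, optimizer, minibatch, training_config):
        # Split one training minibatch into forward microbatches and
        # accumulate gradients so the optimizer step still covers the
        # whole minibatch while peak memory stays bounded.
        micro_bs = training_config["train_microbatch_size"]
        total = minibatch["input_ids"].size(0)
        num_micro = (total + micro_bs - 1) // micro_bs  # ceil division

        optimizer.zero_grad()
        for start in range(0, total, micro_bs):
            end = start + micro_bs
            loss = model(
                input_ids=minibatch["input_ids"][start:end],
                attention_mask=minibatch["attention_mask"][start:end],
                labels=minibatch["labels"][start:end],
            ).loss
            (loss / num_micro).backward()  # average gradients over microbatches
        optimizer.step()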