[feat] add microbatch forwarding (#6251)

* add microbatch forwarding

* fix forward microbatch

* fix producer OOM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change project name

* fix temperature annealing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YeAnbang
2025-03-28 10:24:58 +08:00
committed by GitHub
parent 489f215ad9
commit 50153005b4
5 changed files with 112 additions and 72 deletions


@@ -34,6 +34,7 @@ def launch_distributed(
 inference_microbatch_size: int,
 train_batch_size: int,
 train_microbatch_size: int,
+train_minibatch_size: int,
 dataset_config: Dict[str, Any],
 dataloaders_config: Dict[str, Any],
 inference_model_config: Dict[str, Any],
@@ -99,9 +100,13 @@ def launch_distributed(
 batch_size=train_batch_size,
 model_config=train_model_config,
 plugin_config=plugin_config,
-microbatch_size=train_microbatch_size,
+microbatch_size=train_minibatch_size,
 generate_config=generate_config_consumer,
-training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+training_config={
+"filter_range": [0.05, 9.0],
+"lr": 1e-6,
+"train_microbatch_size": train_microbatch_size,
+},
 num_generations=num_generations,
 )
 procs.append(consumer)
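
The hunks above only show how the sizes are wired through `launch_distributed`; the code that consumes `train_microbatch_size` is not part of this excerpt. As a rough illustration of what "microbatch forwarding" usually means, here is a minimal sketch that assumes each minibatch is split into chunks of at most `train_microbatch_size` items and forwarded chunk by chunk to bound peak memory. The helper names `iter_microbatches` and `forward_in_microbatches` are hypothetical and do not come from the commit.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def iter_microbatches(minibatch: Sequence[T], microbatch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `microbatch_size` items (hypothetical helper)."""
    if microbatch_size <= 0:
        raise ValueError("microbatch_size must be positive")
    for start in range(0, len(minibatch), microbatch_size):
        yield minibatch[start:start + microbatch_size]


def forward_in_microbatches(
    minibatch: Sequence[T],
    microbatch_size: int,
    forward_fn: Callable[[Sequence[T]], List[R]],
) -> List[R]:
    """Run `forward_fn` on each microbatch and concatenate the per-item outputs.

    Assumption: `forward_fn` returns one output per input item, so the
    concatenated result lines up with the original minibatch order.
    """
    outputs: List[R] = []
    for chunk in iter_microbatches(minibatch, microbatch_size):
        outputs.extend(forward_fn(chunk))
    return outputs


if __name__ == "__main__":
    # Stand-in "model" that doubles each item; a real forward pass would go here.
    result = forward_in_microbatches(list(range(10)), 4, lambda chunk: [x * 2 for x in chunk])
    print(result)
```

With a minibatch of 10 items and a microbatch size of 4, the stand-in forward function is called three times, on chunks of 4, 4, and 2 items. Only one chunk is in flight at a time, which is the usual motivation for this pattern.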