Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
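The exact hook set is not visible in this view, but the hunks below are the kind of rewrite a Black-style formatter produces: single quotes become double quotes, hanging-indent argument lists are rewrapped, and trailing commas are added. As a minimal, illustrative sketch (the `line_length=120` value and the snippet itself are assumptions, not taken from the repository's configuration), the same transformation can be reproduced with Black's Python API:

```python
# Illustration only: reproduce the style of rewrite seen in the diff below.
# The repository's real pre-commit hooks and settings are not shown in this commit view.
import black

SRC = """def init_model_from_pretrained(self,
                                   strategy: str,
                                   pretrain: str,
                                   has_optimizer=False):
    self._init_strategy(strategy)
"""

# Assumed line length; the wide lines in the diff suggest a large value such as 120.
print(black.format_str(SRC, mode=black.Mode(line_length=120)))
```

Running this collapses the hanging-indent signature onto one line, which is exactly the pattern repeated throughout the diff that follows.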
@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam
 
 
 class ExperienceCompositionRefs:
-
-    def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
-                 base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
+    def __init__(
+        self,
+        sequences_attention_mask_action_mask_ref: ray.ObjectRef,
+        action_log_probs_ref: ray.ObjectRef,
+        base_action_log_probs_ref: ray.ObjectRef,
+        value_ref: ray.ObjectRef,
+        r_ref: ray.ObjectRef,
+    ) -> None:
         self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
         self.action_log_probs_ref = action_log_probs_ref
         self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ class ExperienceCompositionRefs:
 
 
 class ExperienceMaker:
-
     def __init__(self, kl_coef) -> None:
         self.kl_coef = kl_coef
 
     @torch.no_grad()
     def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
         sequences, attention_mask, action_mask = ray.get(
-            experiment_computation_refs.sequences_attention_mask_action_mask_ref)
+            experiment_computation_refs.sequences_attention_mask_action_mask_ref
+        )
         action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
         base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
         r = ray.get(experiment_computation_refs.r_ref)
@@ -58,11 +63,10 @@ class ExperienceMaker:
 
 
 class DistributedTorchRayActor:
-
     def __init__(self, world_size, rank, local_rank, master_addr, master_port):
-        logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
-                            level=logging.INFO,
-                            datefmt='%Y-%m-%d %H:%M:%S')
+        logging.basicConfig(
+            format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+        )
         self._model = None
         self._world_size = world_size
         self._rank = rank
@@ -82,7 +86,7 @@ class DistributedTorchRayActor:
     @staticmethod
     def _get_free_port():
         with socket.socket() as sock:
-            sock.bind(('', 0))
+            sock.bind(("", 0))
             return sock.getsockname()[1]
 
     def get_master_addr_port(self):
@@ -90,7 +94,6 @@ class DistributedTorchRayActor:
 
 
 class BasePPORole(DistributedTorchRayActor):
-
     def add_experience_maker(self, kl_coef: float = 0.1):
         self._experience_maker = ExperienceMaker(kl_coef)
 
@@ -99,12 +102,12 @@ class BasePPORole(DistributedTorchRayActor):
 
     def _init_strategy(self, strategy: str):
         # configure strategy
-        if strategy == 'ddp':
+        if strategy == "ddp":
             self._strategy = DDPStrategy()
-        elif strategy == 'colossalai_gemini':
-            self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
-        elif strategy == 'colossalai_zero2':
-            self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+        elif strategy == "colossalai_gemini":
+            self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        elif strategy == "colossalai_zero2":
+            self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
         else:
             raise ValueError(f'Unsupported strategy "{strategy}"')
 
@@ -124,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor):
     def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
         raise NotImplementedError()
 
-    def init_model_from_pretrained(self,
-                                   strategy: str,
-                                   model_class: Type[LoRAModule],
-                                   pretrain: str,
-                                   has_optimizer=False):
+    def init_model_from_pretrained(
+        self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
+    ):
         self._init_strategy(strategy)
         self._load_model_from_pretrained(model_class, pretrain)
         self._prepare_model_with_strategy(has_optimizer)
@@ -138,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor):
 
 
 class TrainablePPORole(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -161,38 +161,39 @@ class TrainablePPORole(BasePPORole):
 
 @ray.remote(num_gpus=1)
 class RayPPOActor(TrainablePPORole):
-
     def set_loss_function(self, eps_clip: float):
         self._actor_loss_fn = PolicyLoss(eps_clip)
 
     def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
-        if model_type == 'gpt2':
+        if model_type == "gpt2":
             self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
             self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
-        elif model_type == 'bloom':
+        elif model_type == "bloom":
             self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
             self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
-        elif model_type == 'opt':
+        elif model_type == "opt":
             self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
         else:
             raise ValueError(f'Unsupported model "{model_type}"')
 
         # Set tokenize function for sequence generation
         def _text_input_tokenize_fn(texts):
-            batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+            batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
             return {k: v.cuda() for k, v in batch.items()}
 
         self._sample_tokenize_function = _text_input_tokenize_fn
 
     def setup_generate_kwargs(self, generate_kwargs: dict):
         from coati.trainer.ppo import _set_default_generate_kwargs
+
         self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
-        self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
-        self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
+        self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
+        self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
 
     def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
         import pandas as pd
-        prompts = pd.read_csv(prompt_url)['prompt']
+
+        prompts = pd.read_csv(prompt_url)["prompt"]
         self._sampler = self._strategy.setup_sampler(prompts)
 
     def _generate(self, input_ids, **generate_kwargs):
@@ -214,10 +215,9 @@ class RayPPOActor(TrainablePPORole):
     def _training_step(self, experience):
         num_actions = experience.action_mask.size(1)
         action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
-        actor_loss = self._actor_loss_fn(action_log_probs,
-                                         experience.action_log_probs,
-                                         experience.advantages,
-                                         action_mask=experience.action_mask)
+        actor_loss = self._actor_loss_fn(
+            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+        )
         self._strategy.backward(actor_loss, self._model, self._optimizer)
         self._strategy.optimizer_step(self._optimizer)
         self._optimizer.zero_grad()
@@ -229,17 +229,18 @@ class RayPPOActor(TrainablePPORole):
         self._strategy.save_model(self._model, save_path, only_rank0=True)
         # save optimizer checkpoint on all ranks
         if should_save_optimizer:
-            self._strategy.save_optimizer(self._optimizer,
-                                          'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                                          only_rank0=False)
+            self._strategy.save_optimizer(
+                self._optimizer,
+                "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
+                only_rank0=False,
+            )
 
     def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
-        encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
+        encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
         input_ids = {k: v.cuda() for k, v in encoded_input.items()}
-        sequence, _ = self._model.generate(**input_ids,
-                                           max_length=max_length,
-                                           return_action_mask=False,
-                                           num_return_sequences=num_return_sequences)
+        sequence, _ = self._model.generate(
+            **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
+        )
         token_list = list(sequence.data[0])
         output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
         return output
@@ -247,18 +248,16 @@ class RayPPOActor(TrainablePPORole):
 
 @ray.remote(num_gpus=1)
 class RayPPOCritic(TrainablePPORole):
-
     def set_loss_function(self, value_clip: float):
         self._critic_loss_fn = ValueLoss(value_clip)
 
     def _training_step(self, experience):
-        values = self._model(experience.sequences,
-                             action_mask=experience.action_mask,
-                             attention_mask=experience.attention_mask)
-        critic_loss = self._critic_loss_fn(values,
-                                           experience.values,
-                                           experience.reward,
-                                           action_mask=experience.action_mask)
+        values = self._model(
+            experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+        )
+        critic_loss = self._critic_loss_fn(
+            values, experience.values, experience.reward, action_mask=experience.action_mask
+        )
         self._strategy.backward(critic_loss, self._model, self._optimizer)
         self._strategy.optimizer_step(self._optimizer)
         self._optimizer.zero_grad()
@@ -272,12 +271,12 @@ class RayPPOCritic(TrainablePPORole):
 
 @ray.remote(num_gpus=1)
 class RayPPORewardModel(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
-            self._model = RewardModel(deepcopy(critic.model),
-                                      deepcopy(critic.value_head)).to(torch.cuda.current_device())
+            self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
+                torch.cuda.current_device()
+            )
 
     @torch.no_grad()
     def calculate_r(self, sequence_attention_action_mask):
@@ -287,7 +286,6 @@ class RayPPORewardModel(BasePPORole):
 
 @ray.remote(num_gpus=1)
 class RayPPOInitialModel(BasePPORole):
-
     def _load_model_from_pretrained(self, model_class, pretrain):
         with self._strategy.model_init_context():
             self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -300,8 +298,8 @@ class RayPPOInitialModel(BasePPORole):
 
 class PPORayActorGroup:
     """
-        A group of ray actors
-        Functions start with 'async' should return list of object refs
+    A group of ray actors
+    Functions start with 'async' should return list of object refs
     """
 
     def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -319,8 +317,9 @@ class PPORayActorGroup:
             pg = placement_group(bundles, strategy="STRICT_SPREAD")
             ray.get(pg.ready())
         if pg:
-            master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
-                placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
+            master_actor = self.ray_actor_type.options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
+            ).remote(world_size, 0, 0, None, None)
         else:
             master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
         self._actor_handlers = [master_actor]
@@ -331,16 +330,20 @@ class PPORayActorGroup:
         for rank in range(1, world_size):
             local_rank = rank % self._num_gpus_per_node
             if pg:
-                worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
-                    placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
-                        world_size, rank, local_rank, master_addr, master_port)
+                worker_actor = self.ray_actor_type.options(
+                    scheduling_strategy=PlacementGroupSchedulingStrategy(
+                        placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
+                    )
+                ).remote(world_size, rank, local_rank, master_addr, master_port)
             else:
-                worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
-                                                                              master_addr, master_port)
+                worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
+                    world_size, rank, local_rank, master_addr, master_port
+                )
             self._actor_handlers.append(worker_actor)
 
-    def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
-                                         has_optimizer: bool):
+    def async_init_model_from_pretrained(
+        self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
+    ):
         return [
             actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
             for actor in self._actor_handlers
@@ -348,7 +351,6 @@ class PPORayActorGroup:
 
 
 class TrainableModelRayActorGroup(PPORayActorGroup):
-
     def async_learn_on_experiences(self, experience_refs):
         num_actors = len(self._actor_handlers)
         learn_result_refs = []
@@ -359,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup):
 
 
 class PPOActorRayActorGroup(TrainableModelRayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
 
@@ -381,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
         action_log_probs_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             action_log_probs_refs.append(action_log_probs_ref)
         return action_log_probs_refs
 
@@ -393,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
 
 
 class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
-
    def __init__(self, num_nodes, num_gpus_per_node) -> None:
        super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
 
@@ -402,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
         value_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             value_refs.append(value_ref)
         return value_refs
 
@@ -411,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
 
 
 class PPOInitialRayActorGroup(PPORayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
 
@@ -420,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup):
         base_action_log_probs_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             base_action_log_probs_refs.append(base_action_log_probs_ref)
         return base_action_log_probs_refs
 
 
 class PPORewardRayActorGroup(PPORayActorGroup):
-
     def __init__(self, num_nodes, num_gpus_per_node) -> None:
         super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
 
@@ -435,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup):
         r_refs = []
         for i in range(len(sequences_attention_mask_action_mask_refs)):
             r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
-                sequences_attention_mask_action_mask_refs[i])
+                sequences_attention_mask_action_mask_refs[i]
+            )
             r_refs.append(r_ref)
         return r_refs
 
 
 def main(args):
-    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
-                        level=logging.INFO,
-                        datefmt='%Y-%m-%d %H:%M:%S')
-    if args.model == 'gpt2':
+    logging.basicConfig(
+        format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    if args.model == "gpt2":
         actor_model_class, critic_model_class = GPTActor, GPTCritic
-    elif args.model == 'bloom':
+    elif args.model == "bloom":
         actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
-    elif args.model == 'opt':
+    elif args.model == "opt":
         actor_model_class, critic_model_class = OPTActor, OPTCritic
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -462,13 +464,14 @@ def main(args):
     logging.info("Actors created")
 
     # Prepare model for training
-    generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
+    generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
     ray.get(
-        actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
-        critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
-        initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
-        reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
-        actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
+        actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+        + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+        + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+        + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+        + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
+    )
     logging.info("Models prepared for training")
 
     # Prepare models for training
@@ -483,8 +486,12 @@ def main(args):
     # Start training
     logging.info("Training start")
    # Set all models to eval and add experience maker
-    all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
-        initial_group._actor_handlers + reward_group._actor_handlers
+    all_ray_actors = (
+        actor_group._actor_handlers
+        + critic_group._actor_handlers
+        + initial_group._actor_handlers
+        + reward_group._actor_handlers
+    )
     num_ray_actors = len(all_ray_actors)
     ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
     ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -497,18 +504,28 @@ def main(args):
             time += 1
             # Experience queueing stage
             sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
-                experience_batch_size)
+                experience_batch_size
+            )
             base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
-                sequences_attention_mask_action_mask_refs)
+                sequences_attention_mask_action_mask_refs
+            )
             values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
             r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
             action_log_probs_refs = actor_group.async_calculate_action_log_probs(
-                sequences_attention_mask_action_mask_refs)
-            experience_composition_refs.extend([
-                ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
-                                          base_action_log_probs_refs[i], values_refs[i], r_refs[i])
-                for i in range(len(sequences_attention_mask_action_mask_refs))
-            ])
+                sequences_attention_mask_action_mask_refs
+            )
+            experience_composition_refs.extend(
+                [
+                    ExperienceCompositionRefs(
+                        sequences_attention_mask_action_mask_refs[i],
+                        action_log_probs_refs[i],
+                        base_action_log_probs_refs[i],
+                        values_refs[i],
+                        r_refs[i],
+                    )
+                    for i in range(len(sequences_attention_mask_action_mask_refs))
+                ]
+            )
             # Learning stage
             if time % update_timesteps == 0:
                 experience_refs = []
@@ -519,8 +536,9 @@ def main(args):
                     experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
                 # backward
                 ray.get(
-                    actor_group.async_learn_on_experiences(experience_refs) +
-                    critic_group.async_learn_on_experiences(experience_refs))
+                    actor_group.async_learn_on_experiences(experience_refs)
+                    + critic_group.async_learn_on_experiences(experience_refs)
+                )
                 # clear refs queue
                 experience_composition_refs.clear()
     logging.info("Training finished")
@@ -528,26 +546,24 @@ def main(args):
     actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--prompt_csv_url', type=str)
-    parser.add_argument('--strategy',
-                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='ddp')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
-    parser.add_argument('--pretrain', type=str, default='gpt2')
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
-    parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
-    parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
-    parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
-    parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
+    parser.add_argument("--prompt_csv_url", type=str)
+    parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
+    parser.add_argument("--pretrain", type=str, default="gpt2")
+    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
+    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+    parser.add_argument("--num_episodes", type=int, default=10)
+    parser.add_argument("--max_timesteps", type=int, default=10)
+    parser.add_argument("--update_timesteps", type=int, default=10)
+    parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
+    parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
+    parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
+    parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
+    parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
     args = parser.parse_args()
     ray.init()
     main(args)
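For readers unfamiliar with the pattern the reformatted file relies on, the groups above hand Ray ObjectRefs from one remote actor to another instead of moving tensors through the driver. A minimal sketch of that hand-off, with purely hypothetical names that are not from the repository:

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class Sampler:
    # stand-in for the actor group that produces sequences
    def make_sequences(self, prompt: str) -> str:
        return prompt.upper()


@ray.remote
class Scorer:
    # stand-in for the reward-model group; receives the resolved value
    def score(self, sequences: str) -> int:
        return len(sequences)


sampler, scorer = Sampler.remote(), Scorer.remote()
seq_ref = sampler.make_sequences.remote("some prompt")  # ObjectRef, nothing is fetched locally
score_ref = scorer.score.remote(seq_ref)  # a top-level ObjectRef argument is resolved by Ray
print(ray.get(score_ref))  # 11
```

Note that Ray only auto-resolves ObjectRefs passed as top-level arguments; refs tucked inside a plain object such as ExperienceCompositionRefs are not, which is why make_experience in the diff calls ray.get on each stored field explicitly.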