Commit 486d06a2d5 in hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git)
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>
Committed by: GitHub
Parent: c7aa319ba0
@@ -167,7 +167,7 @@ def _p2p_comm(
     group: ProcessGroup,
     comm_dtype: torch.dtype = torch.float16,
 ):
-    """
+    """
     Send and recv a tensor using P2P communication; used when the pipeline size is 2 to avoid a communication race.

     Args:
@@ -176,7 +176,7 @@ def _p2p_comm(
         peer (int): rank of the peer
         group (ProcessGroup): process group
         comm_dtype (torch.dtype): dtype of the tensor to be sent
-
+
     Returns:
         torch.Tensor: tensor received from previous stage
     """
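The docstring above explains why a combined send/recv exists at all: with exactly two pipeline stages, both ranks want to send to and receive from each other at the same time, and posting the operations one-sidedly can race or deadlock. A minimal, self-contained sketch of that pattern using `torch.distributed.batch_isend_irecv` (the helper name and symmetric shapes are assumptions for illustration, not the repository's code):

import torch
import torch.distributed as dist

def exchange_with_peer(send_tensor: torch.Tensor, peer: int) -> torch.Tensor:
    # Assumes both ranks exchange tensors of the same shape and dtype.
    recv_tensor = torch.empty_like(send_tensor)
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer),
        dist.P2POp(dist.irecv, recv_tensor, peer),
    ]
    # Both ranks enqueue their send and receive as one batch, so neither
    # blocks waiting for the other side to post its matching call first.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_tensor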
@@ -302,7 +302,9 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

-    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
+    def p2p_communicate(
+        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
+    ) -> None:
         """
         Sends the input tensor to the next stage in the pipeline, using `P2POp` in torch.

@@ -313,5 +315,7 @@ class PipelineP2PCommunication:
         if peer is None:
             peer = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
+        recv_tensor = _p2p_comm(
+            output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
+        )
         return recv_tensor
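For context, a hypothetical call site for the reformatted `p2p_communicate` above (the `comm` object and tensor name are assumed). Note that although the signature is annotated `-> None`, the body ends with `return recv_tensor`, so the call does produce a value:

# Hypothetical usage, assuming `comm` is a PipelineP2PCommunication instance
# on a two-stage pipeline and `hidden_states` is this stage's output tensor:
recv_tensor = comm.p2p_communicate(hidden_states, recv_pre=True, comm_dtype=torch.float16)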
@@ -1,6 +1,6 @@
 import time
 from functools import partial
-from typing import Any, Iterable, List, Optional, Union
+from typing import Any, Iterable, Optional, Union

 import torch
 import torch.cuda
@@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule


-class ActionIntervalBuffer():
+class ActionIntervalBuffer:
     """
     The buffer that saves the interval hidden states and the new token for each stage to use.

@@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        assert self.batch_size % self.microbatch_size == 0, \
-            f"Batch size should be divisible by the microbatch size, {self.batch_size}, {self.microbatch_size}"
+        assert (
+            self.batch_size % self.microbatch_size == 0
+        ), f"Batch size should be divisible by the microbatch size, {self.batch_size}, {self.microbatch_size}"
         self.num_microbatches = self.batch_size // self.microbatch_size
         self.round = self.num_microbatches // self.stage_manager.num_stages

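A worked instance of the divisibility check above, with illustrative numbers (not values from the repository):

batch_size, microbatch_size, num_stages = 16, 4, 2   # assumed example values
assert batch_size % microbatch_size == 0
num_microbatches = batch_size // microbatch_size      # 16 // 4 == 4 microbatches
rounds = num_microbatches // num_stages               # 4 // 2 == 2 pipeline rounds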
@@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule):
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)

     def _prepare_inputs_for_interval_stage(self):
-        '''
+        """
         Prepare inputs for the interval stages; for every interval stage, the input is just the past_key_values

         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
-        '''
-        model_inputs = {
-            'past_key_values': self.mb_manager.cur_kv_cache
-        } if self.mb_manager.cur_kv_cache is not None else None
+        """
+        model_inputs = (
+            {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
+        )
         return model_inputs

     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
-        '''
+        """
         Prepare inputs for the new token; the input is a dict with `input_ids`, `attention_mask` and `past_key_values`:
         `input_ids` is the new token, `attention_mask` is the previous mask with a `1` appended at the end,
         and `past_key_values` is the past_key_values saved in the micro batch manager

         Returns:
             dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
-        '''
+        """
         new_mask = self.mb_manager.cur_descrption.attn_mask
         past_key_values = self.mb_manager.cur_descrption.kv_cache

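A sketch of the input layout that docstring describes, as a standalone helper (the helper name and shapes are assumptions for illustration, not the diff's code):

import torch

def build_new_token_inputs(new_token: torch.Tensor, new_mask: torch.Tensor, past_key_values) -> dict:
    # The freshly generated token becomes input_ids, and the attention mask
    # grows by one position to cover the token being decoded this step.
    ones = torch.ones((new_mask.shape[0], 1), dtype=new_mask.dtype, device=new_mask.device)
    return {
        "input_ids": new_token,                            # shape [batch, 1]
        "attention_mask": torch.cat([new_mask, ones], dim=-1),
        "past_key_values": past_key_values,
    }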
@@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule):
         return input_ids

     def _recv_pre_stage(self) -> Any:
-        '''
+        """
         Receive the output from the previous stage

         Returns:
             Any: The output from the previous stage
-        '''
+        """
         if self.stage_manager.num_stages == 2:
             return self.comm.p2p_recv()
         return self.comm.recv_forward()
@@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule):
         output_dict = model_forward(model, inputs_dict, None)

         self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

     def _gen_token_action(self, model: Module):
         """
@@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule):
         """
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
-        hidden_states = {'hidden_states': hidden_states}
+        hidden_states = {"hidden_states": hidden_states}
         logits = model_forward(model, None, hidden_states)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
-        assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
-        new_token = self._get_token_id(logits['logits'])
+        assert (
+            "logits" in logits
+        ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
+        new_token = self._get_token_id(logits["logits"])

         self.mb_manager.step(None, None, new_token)
         self.action_interval_buffer.new_token = new_token
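One plausible reading of the `_get_token_id` helper used above, assuming greedy decoding (the actual helper may sample instead of taking the argmax):

import torch

def get_token_id(logits: torch.Tensor) -> torch.Tensor:
    # Pick the highest-probability id from the distribution at the last position.
    new_token = torch.argmax(logits[:, -1, :], dim=-1)
    return new_token.unsqueeze(1)  # shape [batch, 1], ready for the next step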
@@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule):
         output_dict = model_forward(model, inputs_dict, None)

         self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

     def _body_encoding_action(self, model: Module):
         hidden_states = self.action_interval_buffer.hidden_states
         assert hidden_states is not None, "When not first stage, the hidden states should not be None"
         inputs_dict = self._prepare_inputs_for_interval_stage()
-        hidden_states = {'hidden_states': hidden_states}
+        hidden_states = {"hidden_states": hidden_states}
         output_dict = model_forward(model, inputs_dict, hidden_states)

         self.mb_manager.step(inputs_dict, output_dict, None)
-        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

     def _comm_action(self, recv_pre: bool) -> torch.Tensor:
         """
@@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule):

         whole_timestamp = []

-        #run by round
+        # run by round
         for _ in range(self.round):
-            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
-                              ] if self.verbose and self.stage_manager.is_first_stage() else None
+            self.timestamps = (
+                [[] for _ in range(self.stage_manager.num_stages)]
+                if self.verbose and self.stage_manager.is_first_stage()
+                else None
+            )
             self.action_interval_buffer.clear()
             while self.mb_manager.is_micro_batch_done() is False:
                 actions = self._gen_action(model)
@@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule):
         whole_timestamp = []
         # run by round
         for _ in range(self.round):
-            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
-                              ] if self.verbose and self.stage_manager.is_first_stage() else None
+            self.timestamps = (
+                [[] for _ in range(self.stage_manager.num_stages)]
+                if self.verbose and self.stage_manager.is_first_stage()
+                else None
+            )
             while self.mb_manager.is_micro_batch_done() is False:
                 inputs_dict = None
                 new_token = None
@@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule):
                 hidden_states = self.comm.recv_forward()
             if self.stage_manager.is_first_stage():
                 # First just generate a new token
-                assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
+                assert (
+                    hidden_states is not None
+                ), "When first stage in GENERATE phase, the hidden states should not be None"
                 logits = model_forward(model, None, hidden_states)
                 if self.verbose and self.stage_manager.is_first_stage():
                     torch.cuda.synchronize()
                     self.timestamps[self.mb_manager.idx].append(time.time())
-                assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
-                new_token = self._get_token_id(logits['logits'])
+                assert (
+                    "logits" in logits
+                ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
+                new_token = self._get_token_id(logits["logits"])
                 self.mb_manager.step(None, None, new_token)
                 # If the current micro batch is not DONE, go through blocks
                 if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
@@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule):
                 self.mb_manager.step(inputs_dict, output_dict, None)

             # Current microbatch is not DONE, send hidden_state to next stage
-            if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE,
-                                                                                        Status.COOLDOWN):
-                self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
+            if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
+                Status.GENERATE,
+                Status.COOLDOWN,
+            ):
+                self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})

             self.mb_manager.next()
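The condition reformatted in the last hunk decides which stages forward their hidden states: every non-first stage always sends, while the first stage sends only while its microbatch is still generating or cooling down. A standalone restatement, with an assumed minimal stand-in for the micro batch manager's `Status`:

from enum import Enum

class Status(Enum):  # assumed stand-in, not the repository's definition
    PREFILL = 0
    GENERATE = 1
    COOLDOWN = 2
    DONE = 3

def should_send_forward(is_first_stage: bool, state: Status) -> bool:
    # Mirrors: not is_first_stage or cur_state in (GENERATE, COOLDOWN)
    return (not is_first_stage) or state in (Status.GENERATE, Status.COOLDOWN)

assert should_send_forward(is_first_stage=False, state=Status.DONE)
assert not should_send_forward(is_first_stage=True, state=Status.DONE)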