[format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>
This commit is contained in:
github-actions[bot]
2023-10-18 11:46:37 +08:00
committed by GitHub
parent c7aa319ba0
commit 486d06a2d5
13 changed files with 297 additions and 258 deletions

View File

@@ -167,7 +167,7 @@ def _p2p_comm(
group: ProcessGroup,
comm_dtype: torch.dtype = torch.float16,
):
"""
"""
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
Agrs:
@@ -176,7 +176,7 @@ def _p2p_comm(
peer (int): rank of the peer
group (ProcessGroup): process group
comm_dtype (torch.dtype): dtype of the tensor to be sent
Returns:
torch.Tensor: tensor received from previous stage
"""
@@ -302,7 +302,9 @@ class PipelineP2PCommunication:
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
def p2p_communicate(
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
) -> None:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -313,5 +315,7 @@ class PipelineP2PCommunication:
if peer is None:
peer = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
recv_tensor = _p2p_comm(
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
)
return recv_tensor

View File

@@ -1,6 +1,6 @@
import time
from functools import partial
from typing import Any, Iterable, List, Optional, Union
from typing import Any, Iterable, Optional, Union
import torch
import torch.cuda
@@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
class ActionIntervalBuffer():
class ActionIntervalBuffer:
"""
The buffer to save the interval hidden states and new token for stage to use.
@@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule):
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
assert self.batch_size % self.microbatch_size == 0, \
f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
assert (
self.batch_size % self.microbatch_size == 0
), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
self.num_microbatches = self.batch_size // self.microbatch_size
self.round = self.num_microbatches // self.stage_manager.num_stages
@@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule):
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self):
'''
"""
Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
'''
model_inputs = {
'past_key_values': self.mb_manager.cur_kv_cache
} if self.mb_manager.cur_kv_cache is not None else None
"""
model_inputs = (
{"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
)
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
'''
"""
Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values`
`input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end,
`past_key_values` is the past_key_values save in the micro batch manager
Returns:
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
'''
"""
new_mask = self.mb_manager.cur_descrption.attn_mask
past_key_values = self.mb_manager.cur_descrption.kv_cache
@@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule):
return input_ids
def _recv_pre_stage(self) -> Any:
'''
"""
Receive the output from previous stage
Returns:
Any: The output from previous stage
'''
"""
if self.stage_manager.num_stages == 2:
return self.comm.p2p_recv()
return self.comm.recv_forward()
@@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule):
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
@@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule):
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
hidden_states = {'hidden_states': hidden_states}
hidden_states = {"hidden_states": hidden_states}
logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
assert (
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token)
self.action_interval_buffer.new_token = new_token
@@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule):
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
hidden_states = {'hidden_states': hidden_states}
hidden_states = {"hidden_states": hidden_states}
output_dict = model_forward(model, inputs_dict, hidden_states)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
@@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule):
whole_timestamp = []
#run by round
# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
] if self.verbose and self.stage_manager.is_first_stage() else None
self.timestamps = (
[[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
self.action_interval_buffer.clear()
while self.mb_manager.is_micro_batch_done() is False:
actions = self._gen_action(model)
@@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule):
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
] if self.verbose and self.stage_manager.is_first_stage() else None
self.timestamps = (
[[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
while self.mb_manager.is_micro_batch_done() is False:
inputs_dict = None
new_token = None
@@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule):
hidden_states = self.comm.recv_forward()
if self.stage_manager.is_first_stage():
# First just generate a new token
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
assert (
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
@@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule):
self.mb_manager.step(inputs_dict, output_dict, None)
# Current microbatch is not DONE, send hidden_state to next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE,
Status.COOLDOWN):
self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
Status.GENERATE,
Status.COOLDOWN,
):
self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})
self.mb_manager.next()