mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[hotfix] fix typo change _descrption to _description (#5331)
This commit is contained in:
@@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
+        model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state}
         return model_inputs

     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
         """
-        new_mask = self.mb_manager.cur_descrption.attn_mask
+        new_mask = self.mb_manager.cur_description.attn_mask

         return dict(input_ids=new_token, attention_mask=new_mask)
@@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule):
         1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
-        self.mb_manager.add_descrption(inputs_dict)
+        self.mb_manager.add_description(inputs_dict)

     def _load_stage_action(self, model: Module) -> None:
         """
@@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule):
         1.load micro_batch 2.do the forward 3.step to update
         """
         inputs_dict = self.load_micro_batch()
-        self.mb_manager.add_descrption(inputs_dict)
+        self.mb_manager.add_description(inputs_dict)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
@@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule):
             if self.verbose and self.stage_manager.is_first_stage():
                 torch.cuda.synchronize()
                 self.timestamps[self.mb_manager.idx].append(time.time())
-            self.mb_manager.add_descrption(inputs_dict)
+            self.mb_manager.add_description(inputs_dict)
             interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
@@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule):
         inputs_dict = None
         if self.mb_manager.cur_state is Status.PREFILL:
             inputs_dict = self.load_micro_batch()
-            self.mb_manager.add_descrption(inputs_dict)
+            self.mb_manager.add_description(inputs_dict)
            interval_inputs = {
                "hidden_states": hidden_states["hidden_states"],
                "infer_state": self.mb_manager.cur_infer_state,
Reference in New Issue
Block a user