Mirror of https://github.com/hpcaitech/ColossalAI.git
[pipeline] All bert models (#4233)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be reverted and copied to feature/bloom.
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update the set pipeline forward (sketched after this list)
* shared params
* fix bugs
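
The "bind argument to policy" and "set pipeline forward" items above refer to installing a stage-aware forward on the wrapped model, with the pipeline-specific arguments bound in advance. A minimal sketch of the idea, using hypothetical names (bind_pipeline_forward, new_forward, stage_manager, stage_index) rather than the exact shardformer API:

    import functools
    import types


    def bind_pipeline_forward(module, new_forward, stage_manager, stage_index):
        """Install a stage-aware forward on `module` (illustrative helper only).

        `new_forward(self, *args, stage_manager=..., stage_index=...)` is assumed to be
        a rewritten forward that only runs the layers owned by the current stage.
        """
        # Bind the pipeline-specific keyword arguments up front ...
        stage_forward = functools.partial(new_forward,
                                          stage_manager=stage_manager,
                                          stage_index=stage_index)
        # ... then attach it as a bound method so `module.forward(...)` still receives `self`.
        module.forward = types.MethodType(stage_forward, module)

With this in place, callers keep invoking module.forward(...) as usual; the bound partial forwards the call to the stage-aware implementation.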
@@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any],
     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        if torch.__version__ >= "1.13.0":
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
         object_sizes_tensor = torch.cat(size_list)
     else:
         object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
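
The version check in the hunk above exists because the private helper torch.distributed.distributed_c10d._object_to_tensor only accepts a device argument in newer PyTorch releases. A minimal sketch of the same guard, factored into a hypothetical serialize_objects helper (not part of the codebase); note that a robust check compares parsed versions rather than raw version strings:

    from typing import Any, List, Tuple

    import torch
    import torch.distributed.distributed_c10d as c10d
    from packaging import version


    def serialize_objects(objects: List[Any], device: torch.device) -> Tuple[tuple, tuple]:
        """Serialize Python objects to byte tensors, passing `device` only when supported.

        Illustrative helper; `_object_to_tensor` is a private API and its signature
        may change again in later PyTorch releases (e.g. an extra `group` argument).
        """
        if version.parse(torch.__version__) >= version.parse("1.13.0"):
            pairs = [c10d._object_to_tensor(obj, device=device) for obj in objects]
        else:
            pairs = [c10d._object_to_tensor(obj) for obj in objects]
        tensor_list, size_list = zip(*pairs)
        return tensor_list, size_list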
@@ -205,7 +205,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            # the backward pass.
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

            input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

            if last_iteration:
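
For context, the pop(0) calls above are the steady-state bookkeeping of the one-forward-one-backward (1F1B) schedule: every backward pass consumes the cached inputs and outputs of the oldest in-flight microbatch. A toy sketch of that control flow, with hypothetical forward_step/backward_step callables standing in for the scheduler's real methods and communication:

    from collections import deque
    from typing import Any, Callable, Iterable


    def one_f_one_b(microbatches: Iterable[Any],
                    num_warmup: int,
                    forward_step: Callable[[Any], Any],
                    backward_step: Callable[[Any, Any], None]) -> None:
        """Toy 1F1B loop: warm-up forwards, alternate forward/backward, then drain.

        Cached activations are consumed in FIFO order, mirroring the pop(0)
        calls in the hunk above.
        """
        microbatches = list(microbatches)
        input_objs, output_objs = deque(), deque()

        # Warm-up phase: fill the pipeline with forward passes only.
        for mb in microbatches[:num_warmup]:
            input_objs.append(mb)
            output_objs.append(forward_step(mb))

        # Steady state: one forward, then one backward on the oldest microbatch.
        for mb in microbatches[num_warmup:]:
            input_objs.append(mb)
            output_objs.append(forward_step(mb))
            backward_step(input_objs.popleft(), output_objs.popleft())

        # Cool-down phase: drain the remaining backward passes.
        while input_objs:
            backward_step(input_objs.popleft(), output_objs.popleft())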
@@ -42,6 +42,8 @@ _POLICY_LIST = {
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),

    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel":
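
The registry above keys shardformer policies by the fully qualified class name of the Hugging Face model, so the matching policy can be located from a model instance alone. A minimal sketch of such a lookup, with a simplified PolicyLocation and a hypothetical get_policy_location helper:

    from dataclasses import dataclass


    @dataclass
    class PolicyLocation:
        # Which policy module under the policies package, and which class inside it.
        file_name: str
        class_name: str


    # Simplified registry keyed by "module.ClassName" of the wrapped model.
    _POLICY_LIST = {
        "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
            PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
    }


    def get_policy_location(model) -> PolicyLocation:
        key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
        try:
            return _POLICY_LIST[key]
        except KeyError:
            raise NotImplementedError(f"no auto policy registered for {key}") from None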
File diff suppressed because it is too large
@@ -212,11 +212,13 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
         llama_model = self.model.model
         if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
             # tie weights
-            return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
+            return [{
+                0: llama_model.embed_tokens.weight,
+                self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
+            }]
         return []
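
Each dictionary returned by get_shared_params maps a pipeline stage index to the parameter that stage holds, so tied weights (the input embedding on the first stage and the LM head on the last stage in the hunk above) can have their gradients kept in sync after backward. A minimal sketch of that synchronization, assuming a hypothetical sync_shared_params helper and a process group containing exactly the ranks that own a copy of the tied weight:

    from typing import Dict, List

    import torch.distributed as dist
    from torch import Tensor


    def sync_shared_params(shared_param_groups: List[Dict[int, Tensor]],
                           my_stage: int,
                           group: dist.ProcessGroup) -> None:
        """All-reduce gradients of weights tied across pipeline stages (illustrative only)."""
        for stage_to_param in shared_param_groups:
            if my_stage not in stage_to_param:
                continue  # this stage holds no copy of the tied weight
            param = stage_to_param[my_stage]
            if param.grad is not None:
                # Sum the gradient contributions from every stage that holds a copy,
                # so the tied parameters receive identical updates.
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)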