mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-03 18:46:43 +00:00
fix transformers backend
This commit is contained in:
parent
57b49da5e4
commit
bc0171d392
1
.gitignore
vendored
1
.gitignore
vendored
@ -163,3 +163,4 @@ coverage.xml
|
||||
# log, test files - ColossalChat
|
||||
applications/ColossalChat/logs
|
||||
applications/ColossalChat/tests/logs
|
||||
applications/ColossalChat/wandb
|
||||
|
@ -61,12 +61,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
self.generate_config = generate_config.copy()
|
||||
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = 8
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
input_ids = input_ids.to(get_current_device())
|
||||
attention_mask = attention_mask.to(get_current_device())
|
||||
out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
|
||||
gt_answer = None
|
||||
if "gt_answer" in kwargs:
|
||||
gt_answer = kwargs.pop("gt_answer")
|
||||
if self.num_generations > 1:
|
||||
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
|
||||
out = self.model.generate(
|
||||
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
|
||||
)
|
||||
input_len = input_ids.shape[-1]
|
||||
new_token_ids = out.sequences[:, input_len:]
|
||||
# get log probs
|
||||
@ -76,10 +86,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
|
||||
action_log_probs = torch.cat(action_log_probs, dim=1)
|
||||
# get action mask
|
||||
response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
|
||||
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
|
||||
if self.tokenizer.eos_token_id is not None:
|
||||
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
|
||||
action_mask[indices[0], indices[1] + 1 :] = 0
|
||||
response_idx[:, 0] = input_len
|
||||
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
@ -91,7 +104,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": action_log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if gt_answer is not None:
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
|
@ -10,9 +10,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=32)
|
||||
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
|
||||
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
|
||||
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
|
||||
@ -24,29 +24,31 @@ if __name__ == "__main__":
|
||||
train_model_config = dict(path=args.model)
|
||||
generate_config = dict(
|
||||
top_k=50,
|
||||
top_p=0.8,
|
||||
top_p=0.9,
|
||||
temperature=1.0,
|
||||
)
|
||||
|
||||
if args.backend == "transformers":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
attn_implementation="flash_attention_2",
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
train_model_config.update(
|
||||
dict(
|
||||
attn_implementation="flash_attention_2",
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_cache=False,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_length=512,
|
||||
max_length=1024 + 512,
|
||||
do_sample=True,
|
||||
max_new_tokens=None,
|
||||
early_stopping=False,
|
||||
stop_strings=["</answer>"],
|
||||
)
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
@ -82,12 +84,12 @@ if __name__ == "__main__":
|
||||
num_producers=args.num_inferencer,
|
||||
num_proc_per_producer=1,
|
||||
num_consumer_procs=args.num_trainers,
|
||||
num_episodes=1,
|
||||
num_episodes=10,
|
||||
inference_batch_size=args.inference_batch_size,
|
||||
inference_microbatch_size=args.inference_microbatch_size,
|
||||
train_batch_size=args.train_batch_size,
|
||||
train_microbatch_size=args.train_microbatch_size,
|
||||
dataset_config={"path": args.dataset, "max_length": 256},
|
||||
dataset_config={"path": args.dataset, "max_length": 300},
|
||||
dataloaders_config={},
|
||||
inference_model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
|
Loading…
Reference in New Issue
Block a user