[chat] refactor actor class (#3968)

* refactor: separate log_probs fn from Actor forward fn

* refactor: separate generate fn from Actor class

* feat: update unwrap_model and get_base_model
* unwrap_model returns model not wrapped by Strategy
* get_base_model returns HF model for Actor, Critic and RewardModel

* feat: simplify Strategy.prepare

* style: remove get_base_model method of Actor

* perf: tokenize text in batches

* refactor: move calc_action_log_probs to utils of model

* test: update test with new forward fn

* style: rename forward fn args

* fix: do not unwrap model in save_model fn of naive strategy

* test: add gemini test for train_prompts

* fix: fix _set_default_generate_kwargs
This commit is contained in:
Wenhao Chen
2023-06-13 13:31:56 +08:00
committed by GitHub
parent b3ab7fbabf
commit 9d02590c9a
14 changed files with 151 additions and 120 deletions

View File

@@ -35,14 +35,14 @@ class PromptDataset(Dataset):
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
for data_dict in list_data_dict:
token = tokenizer(data_dict["instruction"],
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in token.items():
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
def __len__(self):
return len(self.keyed_prompt["input_ids"])

View File

@@ -74,21 +74,18 @@ class SFTDataset(Dataset):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int
) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
tokenized_list = tokenizer(
strings, return_tensors="pt", padding="longest",
max_length=max_length, truncation=True
)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
@@ -105,7 +102,10 @@ def preprocess(
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length)
for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):