mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 21:22:04 +00:00

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format

parent 3c6b831c26
commit 079bf3cb26
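This commit swaps the yapf-based tooling for black and re-runs the hooks over the whole tree, which is why the diff below is dominated by quote-style and call-wrapping changes. As a minimal sketch of the "run all files" step (assuming pre-commit is installed and the repository root is the working directory):

    # Sketch of re-running the updated hooks over every file,
    # assuming `pre-commit` is installed and we are at the repo root.
    import subprocess

    # Refresh hook environments after .pre-commit-config.yaml changes.
    subprocess.run(["pre-commit", "install-hooks"], check=True)
    # Apply all hooks (autoflake, isort, black, clang-format, ...) to all files.
    subprocess.run(["pre-commit", "run", "--all-files"], check=True)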
.flake8 (22 lines deleted)
@@ -1,22 +0,0 @@
-[flake8]
-ignore =
-    ;W503 line break before binary operator
-    W503,
-    ;E203 whitespace before ':'
-    E203,
-
-; exclude file
-exclude =
-    .tox,
-    .git,
-    __pycache__,
-    build,
-    dist,
-    *.pyc,
-    *.egg-info,
-    .cache,
-    .eggs
-
-max-line-length = 120
-
-per-file-ignores = __init__.py:F401
.github/workflows/scripts/check_doc_i18n.py (vendored, 12 lines changed)
@@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):

         # If the corresponding item doesn't exist in the second directory, the directories are different
         if not os.path.exists(item_path2):
-            print(f'Found mismatch: {item_path1}, {item_path2}')
+            print(f"Found mismatch: {item_path1}, {item_path2}")
             return False

         # If the corresponding item is a directory, we compare the two directories recursively
         if os.path.isdir(item_path1) and os.path.isdir(item_path2):
             if not compare_dirs(item_path1, item_path2):
-                print(f'Found mismatch: {item_path1}, {item_path2}')
+                print(f"Found mismatch: {item_path1}, {item_path2}")
                 return False

         # both are files
@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):

         # If the corresponding item is not a file or a directory, the directories are different
         else:
-            print(f'Found mismatch: {item_path1}, {item_path2}')
+            print(f"Found mismatch: {item_path1}, {item_path2}")
             return False

     # If all items are the same, the directories are the same
     return True


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
+    parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
     args = parser.parse_args()

     i18n_folders = os.listdir(args.directory)
@@ -56,7 +56,7 @@ if __name__ == '__main__':
     for i in range(1, len(i18n_folders)):
         dir1 = i18n_folders[0]
         dir2 = i18n_folders[i]
-        print(f'comparing {dir1} vs {dir2}')
+        print(f"comparing {dir1} vs {dir2}")
         match = compare_dirs(i18n_folders[0], i18n_folders[i])

         if not match:
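For reference, the logic touched above is a plain recursive directory comparison; a self-contained sketch of the same idea (hypothetical names, not the repository's code):

    import os

    def dirs_match(dir1: str, dir2: str) -> bool:
        # Every item in dir1 must exist in dir2 with the same type;
        # subdirectories are compared recursively.
        for item in os.listdir(dir1):
            p1, p2 = os.path.join(dir1, item), os.path.join(dir2, item)
            if not os.path.exists(p2):
                return False
            if os.path.isdir(p1):
                if not os.path.isdir(p2) or not dirs_match(p1, p2):
                    return False
        return True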

@@ -4,7 +4,7 @@ import os

 def check_inputs(input_list):
     for path in input_list:
-        real_path = os.path.join('examples', path)
+        real_path = os.path.join("examples", path)
         if not os.path.exists(real_path):
             return False
     return True
@@ -12,16 +12,16 @@ def check_inputs(input_list):

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+    parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
     args = parser.parse_args()
     name_list = args.fileNameList.split(",")
     is_correct = check_inputs(name_list)

     if is_correct:
-        print('success')
+        print("success")
     else:
-        print('failure')
+        print("failure")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

@@ -17,21 +17,21 @@ def show_files(path, all_files):


 def join(input_list, sep=None):
-    return (sep or ' ').join(input_list)
+    return (sep or " ").join(input_list)


 def main():
-    contents = show_files('examples/', [])
+    contents = show_files("examples/", [])
     all_loc = []
     for file_loc in contents:
-        split_loc = file_loc.split('/')
+        split_loc = file_loc.split("/")
         # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
         if len(split_loc) >= 4:
-            re_loc = '/'.join(split_loc[1:3])
+            re_loc = "/".join(split_loc[1:3])
             if re_loc not in all_loc:
                 all_loc.append(re_loc)
     print(all_loc)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

@@ -3,7 +3,7 @@ import argparse

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
+    parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
     args = parser.parse_args()
     name_list = args.fileNameList.split(":")
     folder_need_check = set()
@@ -15,10 +15,10 @@ def main():
         # - application
         # - file
         if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
-            folder_need_check.add('/'.join(loc.split("/")[1:3]))
+            folder_need_check.add("/".join(loc.split("/")[1:3]))
     # Output the result using print. Then the shell can get the values.
     print(list(folder_need_check))


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
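The filter above keeps only paths at least four segments deep under examples/ and records the two folder levels below it; a quick worked example (assumed input paths) shows the effect:

    # Worked example of the folder filter above, with assumed input paths.
    changed = ["examples/images/vit/train.py", "examples/requirements.txt"]
    folders = set()
    for loc in changed:
        parts = loc.split("/")
        if parts[0] == "examples" and len(parts) >= 4:
            folders.add("/".join(parts[1:3]))
    print(folders)  # {'images/vit'}; examples/requirements.txt is too shallow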
@@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:

     # prepare header
     headers = {
-        'Authorization': f'Bearer {github_token}',
-        'Accept': 'application/vnd.github+json',
-        'X-GitHub-Api-Version': '2022-11-28'
+        "Authorization": f"Bearer {github_token}",
+        "Accept": "application/vnd.github+json",
+        "X-GitHub-Api-Version": "2022-11-28",
     }

     res = requests.get(url, headers=headers).json()
     repo_list = []

     for item in res:
-        repo_list.append(item['name'])
+        repo_list.append(item["name"])
     return repo_list

@@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
     """
     # prepare header
     headers = {
-        'Authorization': f'Bearer {github_token}',
-        'Accept': 'application/vnd.github+json',
-        'X-GitHub-Api-Version': '2022-11-28'
+        "Authorization": f"Bearer {github_token}",
+        "Accept": "application/vnd.github+json",
+        "X-GitHub-Api-Version": "2022-11-28",
     }

     user_engagement_count = {}
@@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
     # do pagination to the API
     page = 1
     while True:
-        comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}'
+        comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
         comment_response = requests.get(comment_api, headers=headers).json()

         if len(comment_response) == 0:
             break
         else:
             for item in comment_response:
-                comment_author_relationship = item['author_association']
-                if comment_author_relationship != 'MEMBER':
+                comment_author_relationship = item["author_association"]
+                if comment_author_relationship != "MEMBER":
                     # if the comment is not made by our member
                     # we don't count this comment towards user engagement
                     continue

-                issue_id = item['issue_url'].split('/')[-1]
-                issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}'
+                issue_id = item["issue_url"].split("/")[-1]
+                issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
                 issue_response = requests.get(issue_api, headers=headers).json()
-                issue_author_relationship = issue_response['author_association']
+                issue_author_relationship = issue_response["author_association"]

-                if issue_author_relationship != 'MEMBER':
+                if issue_author_relationship != "MEMBER":
                     # this means that the issue/PR is not created by our own people
                     # any comments in this issue/PR by our member will be counted towards the leaderboard
-                    member_name = item['user']['login']
+                    member_name = item["user"]["login"]

                     if member_name in user_engagement_count:
                         user_engagement_count[member_name] += 1
@@ -153,7 +153,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
         if cursor is None:
             offset_str = ""
         else:
-            offset_str = f", after: \"{cursor}\""
+            offset_str = f', after: "{cursor}"'
         query = f"""
         {{
             repository(owner: "{org_name}", name: "{repo_name}"){{
@@ -182,7 +182,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
         if cursor is None:
             offset_str = ""
         else:
-            offset_str = f", before: \"{cursor}\""
+            offset_str = f', before: "{cursor}"'
         query = f"""
         {{
             repository(owner: "{org_name}", name: "{repo_name}"){{
@@ -220,8 +220,8 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
     # a utility function to make call to Github GraphQL API
     def _call_graphql_api(query):
         headers = {"Authorization": f"Bearer {github_token}"}
-        json_data = {'query': query}
-        response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers)
+        json_data = {"query": query}
+        response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
         data = response.json()
         return data

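The `_call_graphql_api` helper above follows the standard pattern for GitHub's GraphQL endpoint: POST a JSON body with a `query` field and a bearer token. A minimal standalone sketch (the token value is a placeholder):

    import requests

    def call_github_graphql(query: str, token: str) -> dict:
        # POST the query as JSON; GitHub's GraphQL API lives at one endpoint.
        response = requests.post(
            "https://api.github.com/graphql",
            json={"query": query},
            headers={"Authorization": f"Bearer {token}"},
        )
        return response.json()

    # Example query: fetch the authenticated user's login.
    # print(call_github_graphql("{ viewer { login } }", "GH_TOKEN"))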
@@ -234,21 +234,21 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
         data = _call_graphql_api(query)
         found_discussion_out_of_time_range = False

-        edges = data['data']['repository']['discussions']['edges']
+        edges = data["data"]["repository"]["discussions"]["edges"]
         if len(edges) == 0:
             break
         else:
             # keep the discussion whose author is not a member
             for edge in edges:
                 # print the discussion title
-                discussion = edge['node']
-                discussion_updated_at = str2datetime(discussion['updatedAt'])
+                discussion = edge["node"]
+                discussion_updated_at = str2datetime(discussion["updatedAt"])

                 # check if the updatedAt is within the last 7 days
                 # if yes, add it to discussion_numbers
                 if discussion_updated_at > since:
-                    if discussion['authorAssociation'] != 'MEMBER':
-                        discussion_numbers.append(discussion['number'])
+                    if discussion["authorAssociation"] != "MEMBER":
+                        discussion_numbers.append(discussion["number"])
                 else:
                     found_discussion_out_of_time_range = True

@@ -256,7 +256,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
             break
         else:
             # update cursor
-            cursor = edges[-1]['cursor']
+            cursor = edges[-1]["cursor"]

     # get the discussion comments and replies made by our member
     user_engagement_count = {}
@@ -269,42 +269,42 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
         data = _call_graphql_api(query)

         # get the comments
-        edges = data['data']['repository']['discussion']['comments']['edges']
+        edges = data["data"]["repository"]["discussion"]["comments"]["edges"]

         # update the cursor
         if len(edges) == 0:
             break
         else:
             # update cursor for pagination
-            cursor = edges[-1]['cursor']
+            cursor = edges[-1]["cursor"]

             for edge in edges:
-                comment = edge['node']
-                if comment['authorAssociation'] == 'MEMBER':
+                comment = edge["node"]
+                if comment["authorAssociation"] == "MEMBER":
                     # check if the updatedAt is within the last 7 days
                     # if yes, add it to user_engagement_count
-                    comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+                    comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
                     if comment_updated_at > since:
-                        member_name = comment['author']['login']
+                        member_name = comment["author"]["login"]
                         if member_name in user_engagement_count:
                             user_engagement_count[member_name] += 1
                         else:
                             user_engagement_count[member_name] = 1

                 # get the replies
-                reply_edges = comment['replies']['edges']
+                reply_edges = comment["replies"]["edges"]
                 if len(reply_edges) == 0:
                     continue
                 else:
                     for reply_edge in reply_edges:
-                        reply = reply_edge['node']
-                        if reply['authorAssociation'] == 'MEMBER':
+                        reply = reply_edge["node"]
+                        if reply["authorAssociation"] == "MEMBER":
                             # check if the updatedAt is within the last 7 days
                             # if yes, add it to discussion_numbers

-                            reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+                            reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
                             if reply_updated_at > since:
-                                member_name = reply['author']['login']
+                                member_name = reply["author"]["login"]
                                 if member_name in user_engagement_count:
                                     user_engagement_count[member_name] += 1
                                 else:
@@ -312,7 +312,9 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
     return user_engagement_count


-def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool:
+def generate_user_engagement_leaderboard_image(
+    github_token: str, org_name: str, repo_list: List[str], output_path: str
+) -> bool:
     """
     Generate the user engagement leaderboard image for stats within the last 7 days

@@ -335,11 +337,14 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
         else:
             total_engagement_count[name] = count

     for repo_name in repo_list:
         print(f"Fetching user engagement count for {repo_name}/{repo_name}")
-        issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str)
-        discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime)
+        issue_pr_engagement_count = get_issue_pull_request_comments(
+            github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
+        )
+        discussion_engagement_count = get_discussion_comments(
+            github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
+        )

         # update the total engagement count
         _update_count(issue_pr_engagement_count)
@@ -363,7 +368,7 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
         # plot the leaderboard
         xlabel = f"Number of Comments made (since {start_datetime_str})"
         ylabel = "Member"
-        title = 'Active User Engagement Leaderboard'
+        title = "Active User Engagement Leaderboard"
         plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
         return True
     else:
@@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
     """
    # request to the Github API to get the users who have contributed in the last 7 days
     headers = {
-        'Authorization': f'Bearer {github_token}',
-        'Accept': 'application/vnd.github+json',
-        'X-GitHub-Api-Version': '2022-11-28'
+        "Authorization": f"Bearer {github_token}",
+        "Accept": "application/vnd.github+json",
+        "X-GitHub-Api-Version": "2022-11-28",
     }

     counter = Counter()
     start_datetime = get_utc_time_one_week_ago()

     def _get_url(org_name, repo_name, page):
-        return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed'
+        return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"

     def _iterate_by_page(org_name, repo_name):
         page = 1
@@ -415,8 +420,8 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou

         # count the pull request and author from response
         for pr_data in response:
-            merged_at = pr_data['merged_at']
-            author = pr_data['user']['login']
+            merged_at = pr_data["merged_at"]
+            author = pr_data["user"]["login"]

             if merged_at is None:
                 continue
@@ -439,7 +444,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
         _iterate_by_page(org_name, repo_name)

     # convert unix timestamp to Beijing datetime
-    bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai'))
+    bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
     bj_start_datetime_str = datetime2str(bj_start_datetime)

     contribution_list = counter.to_sorted_list()
@@ -452,7 +457,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
     if len(author_list) > 0:
         xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
         ylabel = "Contributor"
-        title = 'Active Contributor Leaderboard'
+        title = "Active Contributor Leaderboard"
         plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
         return True
     else:
@@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
         image_path (str): the path to the image to be uploaded
     """
     url = "https://open.feishu.cn/open-apis/im/v1/images"
-    form = {'image_type': 'message', 'image': (open(image_path, 'rb'))}  # replace with the actual path
+    form = {"image_type": "message", "image": (open(image_path, "rb"))}  # replace with the actual path
     multi_form = MultipartEncoder(form)
     headers = {
-        'Authorization': f'Bearer {lark_tenant_token}',  ## obtain the tenant_access_token; replace with the actual token
+        "Authorization": f"Bearer {lark_tenant_token}",  ## obtain the tenant_access_token; replace with the actual token
     }
-    headers['Content-Type'] = multi_form.content_type
+    headers["Content-Type"] = multi_form.content_type
     response = requests.request("POST", url, headers=headers, data=multi_form).json()
-    return response['data']['image_key']
+    return response["data"]["image_key"]


 def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
@@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
         app_id (str): Lark app id
         app_secret (str): Lark app secret
     """
-    url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
-    data = {'app_id': app_id, 'app_secret': app_secret}
+    url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
+    data = {"app_id": app_id, "app_secret": app_secret}
     response = requests.post(url, json=data).json()
-    return response['tenant_access_token']
+    return response["tenant_access_token"]


 def send_image_to_lark(image_key: str, webhook_url: str) -> None:
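Taken together, the Lark helpers above form a small pipeline: mint a tenant token, upload the image, then post it to a webhook. A usage sketch of the script's own functions, with placeholder credentials and an assumed webhook URL:

    # Usage sketch of the helpers defined above; the app id, secret, and
    # webhook URL are placeholders, not real values.
    token = generate_lark_tenant_access_token(app_id="APP_ID", app_secret="APP_SECRET")
    image_key = upload_image_to_lark(token, "contributor_leaderboard.png")
    send_image_to_lark(image_key, "https://open.feishu.cn/open-apis/bot/v2/hook/XXX")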
@@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str):
     requests.post(webhook_url, json=data)


-if __name__ == '__main__':
-    GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
-    CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png'
-    USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png'
+if __name__ == "__main__":
+    GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
+    CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
+    USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
     ORG_NAME = "hpcaitech"

     # get all open source repositories
@@ -527,17 +532,19 @@ if __name__ == '__main__':

     # generate images
     contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
-    engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH)
+    engagement_success = generate_user_engagement_leaderboard_image(
+        GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
+    )

     # upload images
-    APP_ID = os.environ['LARK_APP_ID']
-    APP_SECRET = os.environ['LARK_APP_SECRET']
+    APP_ID = os.environ["LARK_APP_ID"]
+    APP_SECRET = os.environ["LARK_APP_SECRET"]
     LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
     contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
     user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)

     # send message to lark
-    LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL']
+    LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
     message = """本周的社区榜单出炉啦!
 1. 开发贡献者榜单
 2. 用户互动榜单

@@ -7,27 +7,27 @@ import re

 import requests

-COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
-TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
+COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
+TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--out', type=str, help='output path for the release draft', required=True)
-    parser.add_argument('--version', type=str, help='current version to release', required=True)
+    parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
+    parser.add_argument("--version", type=str, help="current version to release", required=True)
     return parser.parse_args()


 def get_latest_tag_commit(headers=None):
     res = requests.get(url=TAGS_API, headers=headers)
     data = res.json()
-    commit_hash = data[0]['commit']['sha']
-    version = data[0]['name']
+    commit_hash = data[0]["commit"]["sha"]
+    version = data[0]["name"]
     return commit_hash, version


 def get_commit_info(commit_hash, headers=None):
-    api = f'{COMMIT_API}/{commit_hash}'
+    api = f"{COMMIT_API}/{commit_hash}"
     res = requests.get(url=api, headers=headers)
     return res.json()

@@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
     results = []

     while True:
-        api = f'{COMMIT_API}?since={since}&per_page=100&page={page}'
+        api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
         resp = requests.get(url=api, headers=headers)
         data = resp.json()

@@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):

 def collate_release_info(commit_info_list):
     results = dict()
-    pattern = pattern = r'\[.*\]'
+    pattern = pattern = r"\[.*\]"

     for commit_info in commit_info_list:
-        author = commit_info['commit']['author']['name']
+        author = commit_info["commit"]["author"]["name"]

         try:
-            author_url = commit_info['author']['url']
+            author_url = commit_info["author"]["url"]
         except:
             # author can be None
             author_url = None
-        msg = commit_info['commit']['message']
+        msg = commit_info["commit"]["message"]
         match = re.search(pattern, msg)

         if match:
-            tag = match.group().lstrip('[').rstrip(']').capitalize()
+            tag = match.group().lstrip("[").rstrip("]").capitalize()
             if tag not in results:
                 results[tag] = []
             results[tag].append((msg, author, author_url))
@@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info):

     for msg, author, author_url in v:
         # only keep the first line
-        msg = msg.split('\n')[0]
+        msg = msg.split("\n")[0]

         if author_url:
-            item = f'{msg} by [{author}]({author_url})\n'
+            item = f"{msg} by [{author}]({author_url})\n"
         else:
-            item = f'{msg} by {author}\n'
-        text.append(f'- {item}')
+            item = f"{msg} by {author}\n"
+        text.append(f"- {item}")

-    text.append('\n')
+    text.append("\n")

     # add full change log
     text.append(
-        f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}')
+        f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
+    )

     return text


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
-    token = os.environ['GITHUB_API_TOKEN']
-    headers = {'Authorization': token}
+    token = os.environ["GITHUB_API_TOKEN"]
+    headers = {"Authorization": token}

     # get previous release tag
     last_release_commit, last_version = get_latest_tag_commit(headers)
     last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
-    last_release_date = last_release_commit_info['commit']['author']['date']
+    last_release_date = last_release_commit_info["commit"]["author"]["date"]

     # get the commits since last release
     commit_info = get_all_commit_info(since=last_release_date, headers=headers)
     commit_info = commit_info[:-1]  # remove the release commit

     # collate into markdown
     release_info = collate_release_info(commit_info)
     markdown_text = generate_release_post_markdown(args.version, last_version, release_info)

     # write into a file
-    with open(args.out, 'w') as f:
+    with open(args.out, "w") as f:
         for line in markdown_text:
             f.write(line)

@@ -5,8 +5,8 @@ import requests

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-m', '--message', type=str)
-    parser.add_argument('-u', '--url', type=str)
+    parser.add_argument("-m", "--message", type=str)
+    parser.add_argument("-u", "--url", type=str)
     return parser.parse_args()

@@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
     requests.post(webhook_url, json=data)


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     send_message_to_lark(args.message, args.url)

@@ -3,3 +3,4 @@ line_length = 120
 multi_line_output=3
 include_trailing_comma = true
 ignore_comments = true
+profile = black
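The one-line addition matters: without `profile = black`, isort and black can fight over multi-line imports, each hook reformatting the other's output. The profile makes isort emit black-compatible style, illustrated below with assumed module names:

    # With "profile = black", isort wraps long import blocks the way black
    # would (parenthesized, one name per line, trailing comma), so the two
    # hooks stop reformatting each other's output. Illustrative shape only:
    #
    # from some_package.submodule import (
    #     first_helper,
    #     second_helper,
    #     third_helper,
    # )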
@@ -1,23 +1,31 @@
 repos:

+    - repo: https://github.com/PyCQA/autoflake
+      rev: v2.2.1
+      hooks:
+          - id: autoflake
+            name: autoflake (python)
+            args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
+
     - repo: https://github.com/pycqa/isort
       rev: 5.12.0
       hooks:
           - id: isort
             name: sort all imports (python)

-    - repo: https://github.com/pre-commit/mirrors-yapf
-      rev: v0.32.0
+    - repo: https://github.com/psf/black-pre-commit-mirror
+      rev: 23.9.1
       hooks:
-          - id: yapf
-            name: yapf formatter
-            args: ['--style=.style.yapf', '--parallel', '--in-place']
+          - id: black
+            name: black formatter
+            args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

     - repo: https://github.com/pre-commit/mirrors-clang-format
       rev: v13.0.1
       hooks:
           - id: clang-format
             name: clang formatter
+            types_or: [c++, c]

     - repo: https://github.com/pre-commit/pre-commit-hooks
       rev: v4.3.0
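The yapf-to-black switch in this config is what produces the bulk of the commit: black normalizes strings to double quotes, adds magic trailing commas, and explodes long calls one argument per line. A small runnable illustration of the wrapping change (hypothetical code mirroring the patterns in this diff):

    def configure(*args, **kwargs):  # stand-in for any long call, illustration only
        return args, kwargs

    # yapf (removed) aligned continuation arguments under the opening
    # parenthesis; black (added) puts one argument per line with a
    # trailing comma, as seen throughout the Python hunks below:
    result = configure(
        "strategy",
        "actor",
        ptx_coef=0,
    )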

@@ -1,5 +0,0 @@
-[style]
-based_on_style = google
-spaces_before_comment = 4
-split_before_logical_operator = true
-column_limit = 120
@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
 def preprocess_batch(samples) -> dict:
     input_ids = torch.stack(samples)
     attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-    return {'input_ids': input_ids, 'attention_mask': attention_mask}
+    return {"input_ids": input_ids, "attention_mask": attention_mask}


 def print_rank_0(*args, **kwargs) -> None:
@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
     B = 1024**3
     M = 1024**2
     K = 1024
-    outputs = ''
+    outputs = ""
     for name, numel in model_dict.items():
-        outputs += f'{name}: '
+        outputs += f"{name}: "
         if numel >= B:
-            outputs += f'{numel / B:.2f} B\n'
+            outputs += f"{numel / B:.2f} B\n"
         elif numel >= M:
-            outputs += f'{numel / M:.2f} M\n'
+            outputs += f"{numel / M:.2f} M\n"
         elif numel >= K:
-            outputs += f'{numel / K:.2f} K\n'
+            outputs += f"{numel / K:.2f} K\n"
         else:
-            outputs += f'{numel}\n'
+            outputs += f"{numel}\n"
     print_rank_0(outputs)


 def get_gpt_config(model_name: str) -> OPTConfig:
     model_map = {
-        '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
-        '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
-        '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
-        '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
-        '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
-        '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
-        '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
-        '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
-        '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
-        '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
+        "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
+        "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+        "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+        "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
+        "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
+        "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+        "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+        "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
+        "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+        "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
     }
     try:
         return model_map[model_name]
@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:


 def main(args):
-    if args.strategy == 'ddp':
+    if args.strategy == "ddp":
         strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
-    elif args.strategy == 'colossalai_gemini_cpu':
-        strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
-    elif args.strategy == 'colossalai_zero2':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero2_cpu':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
-    elif args.strategy == 'colossalai_zero1':
-        strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero1_cpu':
-        strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
+    elif args.strategy == "colossalai_gemini":
+        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+    elif args.strategy == "colossalai_gemini_cpu":
+        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+    elif args.strategy == "colossalai_zero2":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+    elif args.strategy == "colossalai_zero2_cpu":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
+    elif args.strategy == "colossalai_zero1":
+        strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
+    elif args.strategy == "colossalai_zero1_cpu":
+        strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')
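The if/elif ladder above only changes quote style in this commit, but the same dispatch could be written as a table. A self-contained sketch of the equivalent dict-based lookup (the strategy classes here are illustrative stand-ins, not coati imports):

    # Dict-based dispatch equivalent to the if/elif chain above (illustrative;
    # stand-in classes defined locally so the sketch runs on its own).
    class DDPStrategy: ...
    class GeminiStrategy:
        def __init__(self, placement_policy, initial_scale): ...

    STRATEGY_FACTORIES = {
        "ddp": lambda: DDPStrategy(),
        "colossalai_gemini": lambda: GeminiStrategy(placement_policy="cuda", initial_scale=2**5),
        "colossalai_gemini_cpu": lambda: GeminiStrategy(placement_policy="cpu", initial_scale=2**5),
    }

    def build_strategy(name: str):
        try:
            return STRATEGY_FACTORIES[name]()
        except KeyError:
            raise ValueError(f'Unsupported strategy "{name}"') from None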
@@ -103,90 +103,106 @@ def main(args):

     if args.use_kernels:
         from coati.kernels import convert_to_xformer_model
-        actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
-                                                         (actor, critic, initial_model, reward_model))
+        actor, critic, initial_model, reward_model = map(
+            convert_to_xformer_model, (actor, critic, initial_model, reward_model)
+        )

     actor_numel = get_model_numel(actor, strategy)
     critic_numel = get_model_numel(critic, strategy)
     initial_model_numel = get_model_numel(initial_model, strategy)
     reward_model_numel = get_model_numel(reward_model, strategy)
-    print_model_numel({
-        'Actor': actor_numel,
-        'Critic': critic_numel,
-        'Initial model': initial_model_numel,
-        'Reward model': reward_model_numel
-    })
-    performance_evaluator = PerformanceEvaluator(actor_numel,
-                                                 critic_numel,
-                                                 initial_model_numel,
-                                                 reward_model_numel,
-                                                 enable_grad_checkpoint=False,
-                                                 ignore_episodes=1)
+    print_model_numel(
+        {
+            "Actor": actor_numel,
+            "Critic": critic_numel,
+            "Initial model": initial_model_numel,
+            "Reward model": reward_model_numel,
+        }
+    )
+    performance_evaluator = PerformanceEvaluator(
+        actor_numel,
+        critic_numel,
+        initial_model_numel,
+        reward_model_numel,
+        enable_grad_checkpoint=False,
+        ignore_episodes=1,
+    )

-    if args.strategy.startswith('colossalai'):
+    if args.strategy.startswith("colossalai"):
         actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
         critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
     else:
         actor_optim = Adam(actor.parameters(), lr=5e-6)
         critic_optim = Adam(critic.parameters(), lr=5e-6)

-    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
     tokenizer.pad_token = tokenizer.eos_token

     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
-    dataloader = DataLoader(random_prompts,
-                            batch_size=args.experience_batch_size,
-                            shuffle=True,
-                            collate_fn=preprocess_batch)
+    dataloader = DataLoader(
+        random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
+    )

-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         ptx_coef=0,
-                         train_batch_size=args.train_batch_size,
-                         offload_inference_models=args.offload_inference_models,
-                         max_length=512,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         use_cache=True,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         callbacks=[performance_evaluator])
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        ptx_coef=0,
+        train_batch_size=args.train_batch_size,
+        offload_inference_models=args.offload_inference_models,
+        max_length=512,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        use_cache=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        callbacks=[performance_evaluator],
+    )

-    trainer.fit(prompt_dataloader=dataloader,
-                pretrain_dataloader=None,
-                num_episodes=args.num_episodes,
-                num_update_steps=args.num_update_steps,
-                num_collect_steps=args.num_collect_steps)
+    trainer.fit(
+        prompt_dataloader=dataloader,
+        pretrain_dataloader=None,
+        num_episodes=args.num_episodes,
+        num_update_steps=args.num_update_steps,
+        num_collect_steps=args.num_collect_steps,
+    )

-    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
+    print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='125m')
-    parser.add_argument('--critic_model', default='125m')
-    parser.add_argument('--strategy',
-                        choices=[
-                            'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
-                            'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
-                        ],
-                        default='ddp')
-    parser.add_argument('--num_episodes', type=int, default=3)
-    parser.add_argument('--num_collect_steps', type=int, default=8)
-    parser.add_argument('--num_update_steps', type=int, default=1)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0)
-    parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
-    parser.add_argument('--offload_inference_models', action='store_true', default=False)
-    parser.add_argument('--use_kernels', action='store_true', default=False)
+    parser.add_argument("--model", default="125m")
+    parser.add_argument("--critic_model", default="125m")
+    parser.add_argument(
+        "--strategy",
+        choices=[
+            "ddp",
+            "colossalai_gemini",
+            "colossalai_gemini_cpu",
+            "colossalai_zero2",
+            "colossalai_zero2_cpu",
+            "colossalai_zero1",
+            "colossalai_zero1_cpu",
+        ],
+        default="ddp",
+    )
+    parser.add_argument("--num_episodes", type=int, default=3)
+    parser.add_argument("--num_collect_steps", type=int, default=8)
+    parser.add_argument("--num_update_steps", type=int, default=1)
+    parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0)
+    parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
+    parser.add_argument("--offload_inference_models", action="store_true", default=False)
+    parser.add_argument("--use_kernels", action="store_true", default=False)
     args = parser.parse_args()
     main(args)

@@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights

 def get_free_port():
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(('', 0))
+        s.bind(("", 0))
         return s.getsockname()[1]


 def get_local_ip():
     with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-        s.connect(('8.8.8.8', 80))
+        s.connect(("8.8.8.8", 80))
         return s.getsockname()[0]

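`get_free_port` and `get_local_ip` above are two common socket idioms: binding to port 0 lets the OS pick a free port, and connecting a UDP socket towards a public address reveals the outbound interface's IP without sending any traffic. A self-contained restatement:

    import socket

    # Bind to port 0 to let the OS assign a free TCP port, then read it back.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        print("free port:", s.getsockname()[1])

    # connect() on a UDP socket only selects a route; no packet is sent.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        print("local ip:", s.getsockname()[0])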
@@ -36,22 +36,25 @@ def main(args):
     master_addr = str(get_local_ip())
     # trainer_env_info
     trainer_port = str(get_free_port())
-    env_info_trainers = [{
-        'local_rank': '0',
-        'rank': str(rank),
-        'world_size': str(args.num_trainers),
-        'master_port': trainer_port,
-        'master_addr': master_addr
-    } for rank in range(args.num_trainers)]
+    env_info_trainers = [
+        {
+            "local_rank": "0",
+            "rank": str(rank),
+            "world_size": str(args.num_trainers),
+            "master_port": trainer_port,
+            "master_addr": master_addr,
+        }
+        for rank in range(args.num_trainers)
+    ]

     # maker_env_info
     maker_port = str(get_free_port())
     env_info_maker = {
-        'local_rank': '0',
-        'rank': '0',
-        'world_size': '1',
-        'master_port': maker_port,
-        'master_addr': master_addr
+        "local_rank": "0",
+        "rank": "0",
+        "world_size": "1",
+        "master_port": maker_port,
+        "master_addr": master_addr,
     }

     # configure tokenizer
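The env_info dicts rebuilt above mirror the environment variables torch.distributed's env:// initialization reads (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). A sketch of how such a mapping is typically consumed (the helper is hypothetical, not from this diff):

    import os

    def apply_env_info(env_info: dict) -> None:
        # Hypothetical helper: export the mapping as the upper-case
        # variables torch.distributed reads at initialization.
        for key, value in env_info.items():
            os.environ[key.upper()] = value

    apply_env_info({"local_rank": "0", "rank": "0", "world_size": "1",
                    "master_port": "29500", "master_addr": "127.0.0.1"})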
@@ -63,21 +66,27 @@ def main(args):
         critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
         actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
         critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
-        reward_model = get_reward_model_from_args(args.critic_model,
-                                                  config=critic_cfg).requires_grad_(False).half().cuda()
-        if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+        reward_model = (
+            get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+        )
+        if args.initial_model_quant_ckpt is not None and args.model == "llama":
             # quantize initial model
             with low_resource_init(), no_init_weights():
                 initial_model = get_actor_from_args(args.model, config=actor_cfg)
-            initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
-                                                   args.quant_group_size).cuda().requires_grad_(False)
+            initial_model.model = (
+                llama_load_quant(
+                    initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+                )
+                .cuda()
+                .requires_grad_(False)
+            )
         else:
             initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
         return actor, critic, reward_model, initial_model

     # configure Experience Maker
     experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
-        detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
+        detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
         strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
         model_fn=model_fn,
         env_info=env_info_maker,
@ -97,15 +106,18 @@ def main(args):

    def trainer_model_fn():
        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
-        critic = get_critic_from_args(args.critic_model,
-                                      config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+        critic = (
+            get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+            .half()
+            .cuda()
+        )
        return actor, critic

    # configure Trainer
    trainer_refs = [
        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
            experience_maker_holder_name_list=[
-                f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
+                f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
            ],
            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
            model_fn=trainer_model_fn,
@ -114,7 +126,8 @@ def main(args):
            buffer_limit=16,
            eval_performance=True,
            debug=args.debug,
-        ) for i, env_info_trainer in enumerate(env_info_trainers)
+        )
+        for i, env_info_trainer in enumerate(env_info_trainers)
    ]

    dataset_size = args.experience_batch_size * 4
@ -122,7 +135,7 @@ def main(args):
    def data_gen_fn():
        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
        attn_mask = torch.ones_like(input_ids)
-        return {'input_ids': input_ids, 'attention_mask': attn_mask}
+        return {"input_ids": input_ids, "attention_mask": attn_mask}

    def build_dataloader(size):
        dataset = [data_gen_fn() for _ in range(size)]
@ -138,8 +151,10 @@ def main(args):
    wait_tasks = []

    wait_tasks.append(
-        experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
-                                                 num_steps=args.experience_steps))
+        experience_holder_ref.workingloop.remote(
+            partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+        )
+    )

    total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
    for trainer_ref in trainer_refs:
@ -148,31 +163,30 @@ def main(args):
    ray.get(wait_tasks)


-if __name__ == '__main__':
+if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--num_trainers', type=int, default=1)
-    parser.add_argument('--trainer_strategy',
-                        choices=[
-                            'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
-                            'colossalai_zero2_cpu'
-                        ],
-                        default='ddp')
-    parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--critic_pretrain', type=str, default=None)
-    parser.add_argument('--experience_steps', type=int, default=4)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--train_epochs', type=int, default=1)
-    parser.add_argument('--update_steps', type=int, default=2)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--num_trainers", type=int, default=1)
+    parser.add_argument(
+        "--trainer_strategy",
+        choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+        default="ddp",
+    )
+    parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--critic_pretrain", type=str, default=None)
+    parser.add_argument("--experience_steps", type=int, default=4)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--train_epochs", type=int, default=1)
+    parser.add_argument("--update_steps", type=int, default=2)
+    parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

-    parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
-    parser.add_argument('--quant_bits', type=int, default=4)
-    parser.add_argument('--quant_group_size', type=int, default=128)
-    parser.add_argument('--debug', action='store_true')
+    parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+    parser.add_argument("--quant_bits", type=int, default=4)
+    parser.add_argument("--quant_group_size", type=int, default=128)
+    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
    main(args)
@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights

def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(('', 0))
+        s.bind(("", 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-        s.connect(('8.8.8.8', 80))
+        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]

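Two small idioms are at work in these helpers: binding a TCP socket to port 0 asks the OS to assign an arbitrary free ephemeral port, and connect() on a UDP socket sends no packets at all; it merely makes the kernel pick the outbound interface, whose address getsockname() then reports. A standalone sketch of the same behavior (not part of this commit):

import socket

# Port 0 means "any free port"; the OS picks one and getsockname() reveals it.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("", 0))
    print("free port:", s.getsockname()[1])

# UDP connect() does no handshake; it only selects a route toward 8.8.8.8,
# so nothing is sent and the local interface address can be read back.
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
    s.connect(("8.8.8.8", 80))
    print("local ip:", s.getsockname()[0])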
@ -36,23 +36,29 @@ def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
-    env_info_trainers = [{
-        'local_rank': '0',
-        'rank': str(rank),
-        'world_size': str(args.num_trainers),
-        'master_port': trainer_port,
-        'master_addr': master_addr
-    } for rank in range(args.num_trainers)]
+    env_info_trainers = [
+        {
+            "local_rank": "0",
+            "rank": str(rank),
+            "world_size": str(args.num_trainers),
+            "master_port": trainer_port,
+            "master_addr": master_addr,
+        }
+        for rank in range(args.num_trainers)
+    ]

    # maker_env_info
    maker_port = str(get_free_port())
-    env_info_makers = [{
-        'local_rank': '0',
-        'rank': str(rank),
-        'world_size': str(args.num_makers),
-        'master_port': maker_port,
-        'master_addr': master_addr
-    } for rank in range(args.num_makers)]
+    env_info_makers = [
+        {
+            "local_rank": "0",
+            "rank": str(rank),
+            "world_size": str(args.num_makers),
+            "master_port": maker_port,
+            "master_addr": master_addr,
+        }
+        for rank in range(args.num_makers)
+    ]

    # configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@ -63,14 +69,20 @@ def main(args):
        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
        actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
-        reward_model = get_reward_model_from_args(args.critic_model,
-                                                  config=critic_cfg).requires_grad_(False).half().cuda()
-        if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+        reward_model = (
+            get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+        )
+        if args.initial_model_quant_ckpt is not None and args.model == "llama":
            # quantize initial model
            with low_resource_init(), no_init_weights():
                initial_model = get_actor_from_args(args.model, config=actor_cfg)
-            initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
-                                                   args.quant_group_size).cuda().requires_grad_(False)
+            initial_model.model = (
+                llama_load_quant(
+                    initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+                )
+                .cuda()
+                .requires_grad_(False)
+            )
        else:
            initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        return actor, critic, reward_model, initial_model
@ -79,7 +91,7 @@ def main(args):
    experience_holder_refs = [
        ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
            detached_trainer_name_list=[
-                f'trainer{x}'
+                f"trainer{x}"
                for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
            ],
            strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@ -103,8 +115,11 @@ def main(args):

    def trainer_model_fn():
        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
-        critic = get_critic_from_args(args.critic_model,
-                                      config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+        critic = (
+            get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+            .half()
+            .cuda()
+        )
        return actor, critic

    # configure Trainer
@ -130,7 +145,7 @@ def main(args):
    def data_gen_fn():
        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
        attn_mask = torch.ones_like(input_ids)
-        return {'input_ids': input_ids, 'attention_mask': attn_mask}
+        return {"input_ids": input_ids, "attention_mask": attn_mask}

    def build_dataloader(size):
        dataset = [data_gen_fn() for _ in range(size)]
@ -147,43 +162,48 @@ def main(args):

    for experience_holder_ref in experience_holder_refs:
        wait_tasks.append(
-            experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
-                                                     num_steps=args.experience_steps))
+            experience_holder_ref.workingloop.remote(
+                partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+            )
+        )

-    total_steps = args.experience_batch_size * args.experience_steps * \
-        args.num_makers // (args.num_trainers * args.train_batch_size)
+    total_steps = (
+        args.experience_batch_size
+        * args.experience_steps
+        * args.num_makers
+        // (args.num_trainers * args.train_batch_size)
+    )
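This budget simply balances producers against consumers: the makers together emit experience_batch_size * experience_steps * num_makers samples, while each of the num_trainers trainers consumes train_batch_size samples per step. A quick standalone check with the argparse defaults used below:

# Arithmetic check using the defaults from the argument parser below.
experience_batch_size, experience_steps = 8, 4
num_makers, num_trainers, train_batch_size = 1, 1, 8

total_samples = experience_batch_size * experience_steps * num_makers  # 32 samples produced
total_steps = total_samples // (num_trainers * train_batch_size)       # 4 steps per trainer
assert total_steps == 4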
    for trainer_ref in trainer_refs:
        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

    ray.get(wait_tasks)


-if __name__ == '__main__':
+if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--num_makers', type=int, default=1)
-    parser.add_argument('--num_trainers', type=int, default=1)
-    parser.add_argument('--trainer_strategy',
-                        choices=[
-                            'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
-                            'colossalai_zero2_cpu'
-                        ],
-                        default='ddp')
-    parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--critic_pretrain', type=str, default=None)
-    parser.add_argument('--experience_steps', type=int, default=4)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--train_epochs', type=int, default=1)
-    parser.add_argument('--update_steps', type=int, default=2)
-    parser.add_argument('--train_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--num_makers", type=int, default=1)
+    parser.add_argument("--num_trainers", type=int, default=1)
+    parser.add_argument(
+        "--trainer_strategy",
+        choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+        default="ddp",
+    )
+    parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--critic_pretrain", type=str, default=None)
+    parser.add_argument("--experience_steps", type=int, default=4)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--train_epochs", type=int, default=1)
+    parser.add_argument("--update_steps", type=int, default=2)
+    parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

-    parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
-    parser.add_argument('--quant_bits', type=int, default=4)
-    parser.add_argument('--quant_group_size', type=int, default=128)
-    parser.add_argument('--debug', action='store_true')
+    parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+    parser.add_argument("--quant_bits", type=int, default=4)
+    parser.add_argument("--quant_group_size", type=int, default=128)
+    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
    main(args)
@ -4,7 +4,10 @@ from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
-    'RmStaticDataset', 'HhRlhfDataset',
-    'SFTDataset', 'SupervisedDataset',
-    'PromptDataset', 'is_rank_0',
+    "RmStaticDataset",
+    "HhRlhfDataset",
+    "SFTDataset",
+    "SupervisedDataset",
+    "PromptDataset",
+    "is_rank_0",
]
@ -49,7 +49,7 @@ class Conversation:

    def to_gradio_chatbot(self):
        ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
@ -57,12 +57,14 @@ class Conversation:
        return ret

    def copy(self):
-        return Conversation(system=self.system,
-                            roles=self.roles,
-                            messages=[[x, y] for x, y in self.messages],
-                            offset=self.offset,
-                            sep_style=self.sep_style,
-                            sep=self.sep)
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+        )

    def dict(self):
        return {
@ -70,7 +72,7 @@ class Conversation:
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
-            "sep": self.sep
+            "sep": self.sep,
        }

@ -13,11 +13,13 @@ from .utils import jload
class PromptDataset(Dataset):
    """Dataset for supervised fine-tuning."""

-    def __init__(self,
-                 data_path: str,
-                 tokenizer: transformers.PreTrainedTokenizer,
-                 max_datasets_size: int = None,
-                 max_length: int = 96):
+    def __init__(
+        self,
+        data_path: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        max_datasets_size: int = None,
+        max_length: int = 96,
+    ):
        super(PromptDataset, self).__init__()
        self.keyed_prompt = defaultdict(list)
        self.logger = get_dist_logger()
@ -30,11 +32,9 @@ class PromptDataset(Dataset):
            list_data_dict = list_data_dict[:max_datasets_size]

        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
-        tokens = tokenizer(instructions,
-                           return_tensors='pt',
-                           max_length=max_length,
-                           padding='max_length',
-                           truncation=True)
+        tokens = tokenizer(
+            instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
+        )
        for k, tensor in tokens.items():
            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
-        self.end_token = tokenizer.eos_token \
-            if special_token is None else special_token
+        self.end_token = tokenizer.eos_token if special_token is None else special_token

-        chosen = [
-            data["prompt"] + data["chosen"] + self.end_token
-            for data in tqdm(dataset, disable=not is_rank_0())
-        ]
-        chosen_token = tokenizer(chosen,
-                                 max_length=max_length,
-                                 padding="max_length",
-                                 truncation=True,
-                                 return_tensors="pt")
-        self.chosen = {
-            "input_ids": chosen_token["input_ids"],
-            "attention_mask": chosen_token["attention_mask"]
-        }
+        chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+        chosen_token = tokenizer(
+            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

-        reject = [
-            data["prompt"] + data["rejected"] + self.end_token
-            for data in tqdm(dataset, disable=not is_rank_0())
-        ]
-        reject_token = tokenizer(reject,
-                                 max_length=max_length,
-                                 padding="max_length",
-                                 truncation=True,
-                                 return_tensors="pt")
-        self.reject = {
-            "input_ids": reject_token["input_ids"],
-            "attention_mask": reject_token["attention_mask"]
-        }
+        reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+        reject_token = tokenizer(
+            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        length = self.chosen["input_ids"].shape[0]
        return length

    def __getitem__(self, idx):
-        return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
-            self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+        return (
+            self.chosen["input_ids"][idx],
+            self.chosen["attention_mask"][idx],
+            self.reject["input_ids"][idx],
+            self.reject["attention_mask"][idx],
+        )

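Each item is therefore a 4-tuple of (chosen ids, chosen mask, rejected ids, rejected mask), the shape a pairwise ranking objective such as the LogSigLoss exported elsewhere in this commit expects. A hypothetical sketch of consuming such batches; the reward_model call and loss form here are illustrative assumptions, not this repo's actual training loop:

import torch.nn.functional as F
from torch.utils.data import DataLoader

# `dataset` is an RmStaticDataset instance; `reward_model` maps (ids, mask) -> (B,) scores.
loader = DataLoader(dataset, batch_size=8, shuffle=True)
for chosen_ids, chosen_mask, reject_ids, reject_mask in loader:
    r_chosen = reward_model(chosen_ids, attention_mask=chosen_mask)
    r_reject = reward_model(reject_ids, attention_mask=reject_mask)
    # Pairwise log-sigmoid ranking loss: push chosen scores above rejected ones.
    loss = -F.logsigmoid(r_chosen - r_reject).mean()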
# Anthropic/hh-rlhf
@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
-        self.end_token = tokenizer.eos_token \
-            if special_token is None else special_token
+        self.end_token = tokenizer.eos_token if special_token is None else special_token

-        chosen = [
-            data["chosen"] + self.end_token
-            for data in tqdm(dataset, disable=not is_rank_0())
-        ]
-        chosen_token = tokenizer(chosen,
-                                 max_length=max_length,
-                                 padding="max_length",
-                                 truncation=True,
-                                 return_tensors="pt")
-        self.chosen = {
-            "input_ids": chosen_token["input_ids"],
-            "attention_mask": chosen_token["attention_mask"]
-        }
+        chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+        chosen_token = tokenizer(
+            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

-        reject = [
-            data["rejected"] + self.end_token
-            for data in tqdm(dataset, disable=not is_rank_0())
-        ]
-        reject_token = tokenizer(reject,
-                                 max_length=max_length,
-                                 padding="max_length",
-                                 truncation=True,
-                                 return_tensors="pt")
-        self.reject = {
-            "input_ids": reject_token["input_ids"],
-            "attention_mask": reject_token["attention_mask"]
-        }
+        reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+        reject_token = tokenizer(
+            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        length = self.chosen["input_ids"].shape[0]
        return length

    def __getitem__(self, idx):
-        return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
-            self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
+        return (
+            self.chosen["input_ids"][idx],
+            self.chosen["attention_mask"][idx],
+            self.reject["input_ids"][idx],
+            self.reject["attention_mask"][idx],
+        )
@ -16,10 +16,11 @@ import copy
from typing import Dict, Sequence, Tuple

import torch
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload
@ -28,32 +29,33 @@ logger = get_dist_logger()

IGNORE_INDEX = -100
PROMPT_DICT = {
-    "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
-                     "Write a response that appropriately completes the request.\n\n"
-                     "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
-    "prompt_no_input": ("Below is an instruction that describes a task. "
-                        "Write a response that appropriately completes the request.\n\n"
-                        "### Instruction:\n{instruction}\n\n### Response:"),
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
}


-def _preprocess(sources: Sequence[str],
-                targets: Sequence[str],
-                tokenizer: PreTrainedTokenizer,
-                max_length: int,
-                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess(
+    sources: Sequence[str],
+    targets: Sequence[str],
+    tokenizer: PreTrainedTokenizer,
+    max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess the data by tokenizing."""
    sequences = [s + t for s, t in zip(sources, targets)]
-    sequences_token = tokenizer(sequences,
-                                max_length=max_length,
-                                padding="max_length",
-                                truncation=True,
-                                return_tensors="pt")
-    sources_token = tokenizer(sources,
-                              max_length=max_length,
-                              padding="max_length",
-                              truncation=True,
-                              return_tensors="pt")
+    sequences_token = tokenizer(
+        sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+    )
+    sources_token = tokenizer(
+        sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+    )

    labels = copy.deepcopy(sequences_token["input_ids"])
    for i in range(labels.shape[0]):
@ -64,18 +66,19 @@ def _preprocess(sources: Sequence[str],
            labels[i][:source_len] = IGNORE_INDEX
        elif tokenizer.padding_side == "left":
            # |pad|prompt|completion|eos|
-            labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
+            labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
        else:
            raise RuntimeError()

    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]

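The -100 used here is not arbitrary: it matches the default ignore_index of torch.nn.CrossEntropyLoss, so prompt positions masked this way simply drop out of the loss and only completion tokens are trained on. A toy illustration (values invented for the example):

import torch
import torch.nn.functional as F

# Two prompt tokens masked with IGNORE_INDEX, two completion tokens kept.
logits = torch.randn(4, 32)                # (seq_len, vocab_size)
labels = torch.tensor([-100, -100, 7, 3])  # ignore_index=-100 by default
loss = F.cross_entropy(logits, labels)     # averaged over the 2 kept tokens only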

-def _preprocess_chatglm(sources: Sequence[str],
-                        targets: Sequence[str],
-                        tokenizer: PreTrainedTokenizer,
-                        max_length: int,
-                        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def _preprocess_chatglm(
+    sources: Sequence[str],
+    targets: Sequence[str],
+    tokenizer: PreTrainedTokenizer,
+    max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Preprocess the data by tokenizing.
    None for attention mask, ChatGLM will calculate attention mask according to input ids
@ -90,15 +93,15 @@ def _preprocess_chatglm(sources: Sequence[str],
        # truncate
        sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
        truncate_length = max(0, len(input_id) - max_length)
-        input_id = input_id[truncate_length: ]
+        input_id = input_id[truncate_length:]
        if truncate_length == len(source_id) + 1:
-            input_id = sp_token_list + input_id[1: ]
+            input_id = sp_token_list + input_id[1:]
        elif truncate_length > len(source_id) + 1:
-            input_id = sp_token_list + input_id[2: ]
+            input_id = sp_token_list + input_id[2:]

        context_length = input_id.index(tokenizer.bos_token_id)
        mask_position = context_length - 1
-        label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
+        label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]

        pad_len = max_length - len(input_id)
        input_id = input_id + [tokenizer.pad_token_id] * pad_len
@ -117,25 +120,18 @@ class SFTDataset(Dataset):
        max_length: max length of input
    """

-    def __init__(self,
-                 dataset: Dict,
-                 tokenizer: PreTrainedTokenizer,
-                 max_length: int = 512
-                 ) -> None:
+    def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []

        sources = [data["prompt"] for data in dataset]
-        targets = [
-            data["completion"] + tokenizer.eos_token
-            for data in tqdm(dataset, disable=not is_rank_0())
-        ]
+        targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
        if isinstance(tokenizer, ChatGLMTokenizer):
-            self.input_ids, self.labels, self.attention_mask = \
-                _preprocess_chatglm(sources, targets, tokenizer, max_length)
+            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+                sources, targets, tokenizer, max_length
+            )
        else:
-            self.input_ids, self.labels, self.attention_mask = \
-                _preprocess(sources, targets, tokenizer, max_length)
+            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
@ -143,22 +139,17 @@ class SFTDataset(Dataset):

    def __getitem__(self, idx):
        if self.attention_mask is not None:
-            return dict(input_ids=self.input_ids[idx],
-                        labels=self.labels[idx],
-                        attention_mask=self.attention_mask[idx])
+            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
-            return dict(input_ids=self.input_ids[idx],
-                        labels=self.labels[idx])
+            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

-    def __init__(self,
-                 data_path: str,
-                 tokenizer: PreTrainedTokenizer,
-                 max_datasets_size: int = None,
-                 max_length: int = 512):
+    def __init__(
+        self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
+    ):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
@ -174,18 +165,15 @@ class SupervisedDataset(Dataset):
            prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
-        targets = [
-            example['output'] + tokenizer.eos_token
-            for example in list_data_dict
-        ]
+        targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
-            self.input_ids, self.labels, self.attention_mask = \
-                _preprocess_chatglm(sources, targets, tokenizer, max_length)
+            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+                sources, targets, tokenizer, max_length
+            )
        else:
-            self.input_ids, self.labels, self.attention_mask = \
-                _preprocess(sources, targets, tokenizer, max_length)
+            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
@ -193,9 +181,6 @@ class SupervisedDataset(Dataset):

    def __getitem__(self, idx):
        if self.attention_mask is not None:
-            return dict(input_ids=self.input_ids[idx],
-                        labels=self.labels[idx],
-                        attention_mask=self.attention_mask[idx])
+            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
-            return dict(input_ids=self.input_ids[idx],
-                        labels=self.labels[idx])
+            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
@ -1,4 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer

-__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
+__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
@ -7,9 +7,9 @@ from coati.experience_maker.base import Experience
class ExperienceBuffer(ABC):
    """Experience buffer base class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
@ -11,23 +11,23 @@ from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveExperienceBuffer(ExperienceBuffer):
    """Naive experience buffer class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
        super().__init__(sample_batch_size, limit)
        self.cpu_offload = cpu_offload
-        self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
+        self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        if self.cpu_offload:
-            experience.to_device(torch.device('cpu'))
+            experience.to_device(torch.device("cpu"))
        items = split_experience_batch(experience)
        self.items.extend(items)
        if self.limit > 0:
@ -21,6 +21,7 @@ class BufferItem:

    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
@ -33,8 +34,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
-    keys = ('sequences', 'action_log_probs', 'values',
-            'reward', 'advantages', 'attention_mask', 'action_mask')
+    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
    return items


-def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
-    assert side in ('left', 'right')
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
+    assert side in ("left", "right")
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
-        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
+        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)

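For a 1-D tensor, F.pad takes a (left, right) pair of pad counts, so (pad_len, 0) prepends zeros while (0, pad_len) appends them. A quick standalone check of the left-padding default:

import torch
import torch.nn.functional as F

seqs = [torch.tensor([1, 2]), torch.tensor([3, 4, 5, 6])]
max_len = max(s.size(0) for s in seqs)
padded = torch.stack([F.pad(s, (max_len - s.size(0), 0)) for s in seqs])
# padded == tensor([[0, 0, 1, 2],
#                   [3, 4, 5, 6]])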
def make_experience_batch(items: List[BufferItem]) -> Experience:
    kwargs = {}
-    to_pad_keys = set(('action_log_probs', 'action_mask'))
-    keys = ('sequences', 'action_log_probs', 'values',
-            'reward', 'advantages', 'attention_mask', 'action_mask')
+    to_pad_keys = set(("action_log_probs", "action_mask"))
+    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if key in to_pad_keys:
@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker

-__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
+__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
@ -24,6 +24,7 @@ class Experience:

    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
@ -58,13 +59,9 @@ class Experience:


class ExperienceMaker(ABC):
-    def __init__(self,
-                 actor: Actor,
-                 critic: nn.Module,
-                 reward_model: nn.Module,
-                 initial_model: Actor,
-                 kl_coef: float = 0.1) -> None:
+    def __init__(
+        self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
+    ) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic
@ -23,22 +23,21 @@ class NaiveExperienceMaker(ExperienceMaker):

        # calculate auxiliary tensors
        attention_mask = None
-        pad_token_id = generate_kwargs.get('pad_token_id', None)
+        pad_token_id = generate_kwargs.get("pad_token_id", None)
        if pad_token_id is not None:
-            attention_mask = sequences.not_equal(pad_token_id)\
-                .to(dtype=torch.long, device=sequences.device)
+            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

        input_len = input_ids.size(1)
-        eos_token_id = generate_kwargs.get('eos_token_id', None)
+        eos_token_id = generate_kwargs.get("eos_token_id", None)
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
            action_mask[:, :input_len] = False
            action_mask = action_mask[:, 1:]
-            action_mask = action_mask[:, -(sequences.size(1) - input_len):]
+            action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
            num_actions = action_mask.size(1)

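The EOS masking above is compact enough to deserve a walk-through: within the generated region, (tokens == eos_token_id).cumsum(-1) == 0 stays True only strictly before the first EOS, and the subsequent pad/shift re-includes the EOS position itself, per the inline comment. A toy trace of the cumsum step (standalone illustration):

import torch

eos_token_id = 2
generated = torch.tensor([[5, 7, 2, 9]])                 # EOS at index 2
keep = (generated == eos_token_id).cumsum(dim=-1) == 0   # [[True, True, False, False]]
# After the F.pad shift in the code above, the EOS itself is kept as an action,
# so tokens [5, 7, 2] count as actions while the trailing 9 stays masked out.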
        actor_output = self.actor(sequences, attention_mask)
@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model

__all__ = [
-    'convert_to_xformer_model',
-    'recover_from_xformer_model',
+    "convert_to_xformer_model",
+    "recover_from_xformer_model",
]
@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention):
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        if not self.training:
-            return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
-                                   output_attentions)
+            return super().forward(
+                hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
+            )
        """Input shape: Batch x Time x Channel"""
-        assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
-        assert not output_attentions, 'Xformers attention does not support output_attentions'
+        assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
+        assert not output_attentions, "Xformers attention does not support output_attentions"

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
@ -69,12 +70,14 @@ class XOPTAttention(OPTAttention):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attn_output = xops.memory_efficient_attention(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      attn_bias=xops.LowerTriangularMask(),
-                                                      p=self.dropout if self.training else 0.0,
-                                                      scale=self.scaling)
+        attn_output = xops.memory_efficient_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_bias=xops.LowerTriangularMask(),
+            p=self.dropout if self.training else 0.0,
+            scale=self.scaling,
+        )

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
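As an aside, PyTorch 2.x exposes a comparable fused causal kernel in core as scaled_dot_product_attention. A rough equivalent of the call above is sketched here as an assumption for illustration only; note that xformers expects (batch, seq, heads, head_dim) while the torch op expects (batch, heads, seq, head_dim), and the explicit scale argument needs PyTorch >= 2.1:

import torch.nn.functional as F

# Hypothetical drop-in, with q/k/v laid out as (batch, heads, seq, head_dim):
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    dropout_p=self.dropout if self.training else 0.0,
    is_causal=True,  # plays the role of xops.LowerTriangularMask()
    scale=self.scaling,
)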
@ -3,6 +3,13 @@ from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss

__all__ = [
-    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
-    'LoRAModule', 'convert_to_lora_module'
+    "Actor",
+    "Critic",
+    "RewardModel",
+    "PolicyLoss",
+    "ValueLoss",
+    "LogSigLoss",
+    "LogExpLoss",
+    "LoRAModule",
+    "convert_to_lora_module",
]
@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    Returns:
        nn.Module: the base model
    """
-    assert isinstance(model, (Actor, Critic, RewardModel)), \
-        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+    assert isinstance(
+        model, (Actor, Critic, RewardModel)
+    ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
    return model.model


-__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
+__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
@ -16,18 +16,17 @@ class Actor(LoRAModule):
        lora_train_bias (str): LoRA bias training mode.
    """

-    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,  # HACK: `generate` method may pass more kwargs
    ) -> torch.Tensor:
-        """Returns model output.
-        """
+        """Returns model output."""
        output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
        return output
@ -23,22 +23,23 @@ class Critic(LoRAModule):
        model: nn.Module,
        value_head: nn.Module,
        lora_rank: int = 0,
-        lora_train_bias: str = 'none',
+        lora_train_bias: str = "none",
        use_action_mask: bool = False,
    ) -> None:

        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.value_head = value_head
        self.use_action_mask = use_action_mask
        self.convert_to_lora()

-    def forward(self,
-                sequences: torch.LongTensor,
-                action_mask: Optional[torch.Tensor] = None,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self,
+        sequences: torch.LongTensor,
+        action_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
-        last_hidden_states = outputs['last_hidden_state']
+        last_hidden_states = outputs["last_hidden_state"]

        values = self.value_head(last_hidden_states).squeeze(-1)

@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
        lora_train_bias (str): LoRA bias training mode.
    """

-    def __init__(self,
-                 model: nn.Module,
-                 value_head: Optional[nn.Module] = None,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        model: nn.Module,
+        value_head: Optional[nn.Module] = None,
+        lora_rank: int = 0,
+        lora_train_bias: str = "none",
+    ) -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()
@ -35,7 +37,7 @@ class RewardModel(LoRAModule):

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
-        last_hidden_states = outputs['last_hidden_state']
+        last_hidden_states = outputs["last_hidden_state"]
        values = self.value_head(last_hidden_states)[:, :-1]
        value = values.mean(dim=1).squeeze(1)  # ensure shape is (B)
        return value
@ -2,4 +2,4 @@ from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM

-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]
@ -1,7 +1,6 @@
from typing import Optional

-import torch
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomForCausalLM

from ..base import Actor

@ -18,12 +17,14 @@ class BLOOMActor(Actor):
        lora_train_bias (str): LoRA bias training mode.
    """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        pretrained: str = None,
+        config: Optional[BloomConfig] = None,
+        checkpoint: bool = False,
+        lora_rank: int = 0,
+        lora_train_bias: str = "none",
+    ) -> None:
        if pretrained is not None:
            model = BloomForCausalLM.from_pretrained(pretrained)
        elif config is not None:
@@ -1,8 +1,7 @@
 from typing import Optional

-import torch
 import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel

 from ..base import Critic

@@ -18,12 +17,14 @@ class BLOOMCritic(Critic):
         lora_train_bias (str): LoRA bias training mode.
     """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[BloomConfig] = None,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none',
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        pretrained: str = None,
+        config: Optional[BloomConfig] = None,
+        lora_rank: int = 0,
+        lora_train_bias: str = "none",
+        **kwargs,
+    ) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:
@@ -1,7 +1,7 @@
 from typing import Optional

 import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel

 from ..base import RewardModel

@@ -17,11 +17,13 @@ class BLOOMRM(RewardModel):
         lora_train_bias (str): LoRA bias training mode.
     """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[BloomConfig] = None,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        pretrained: str = None,
+        config: Optional[BloomConfig] = None,
+        lora_rank: int = 0,
+        lora_train_bias: str = "none",
+    ) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:
@@ -1,3 +1,3 @@
 from .chatglm_actor import ChatGLMActor

-__all__ = ['ChatGLMActor']
+__all__ = ["ChatGLMActor"]
@@ -1,11 +1,9 @@
 from typing import Optional

-import torch
+from ..base import Actor
 from .configuration_chatglm import ChatGLMConfig
 from .modeling_chatglm import ChatGLMForConditionalGeneration

-from ..base import Actor
-

 class ChatGLMActor(Actor):
     """
@@ -19,10 +17,9 @@ class ChatGLMActor(Actor):
     do not support lora for now.
     """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[ChatGLMConfig] = None,
-                 checkpoint: bool = False) -> None:
+    def __init__(
+        self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
+    ) -> None:
         if pretrained is not None:
             model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
         elif config is not None:
@@ -31,4 +28,4 @@ class ChatGLMActor(Actor):
             model = ChatGLMForConditionalGeneration(ChatGLMConfig())
         if checkpoint:
             model.gradient_checkpointing_enable()
-        super().__init__(model, lora_rank=0, lora_train_bias='none')
+        super().__init__(model, lora_rank=0, lora_train_bias="none")
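`gradient_checkpointing_enable()` in the actor above is the stock Hugging Face switch that trades compute for memory: block activations are recomputed during backward instead of being stored. A hedged usage sketch, with a placeholder model path and the usual companion setting (the `use_cache` line is standard practice for checkpointed training, not something this diff adds):

from transformers import AutoModelForCausalLM

# "your-model-path" is a placeholder, not a checkpoint shipped with this repo.
model = AutoModelForCausalLM.from_pretrained("your-model-path")
model.gradient_checkpointing_enable()  # recompute activations in backward to save memory
model.config.use_cache = False         # the KV cache is typically disabled while training with checkpointing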
@@ -2,15 +2,14 @@
 This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
 """
 """Tokenization classes for ChatGLM."""
-from typing import List, Optional, Union
 import os
+from typing import Dict, List, Optional, Union

-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.utils import logging, PaddingStrategy
-from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
-from typing import Dict
-import sentencepiece as spm
 import numpy as np
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
+from transformers.utils import PaddingStrategy, logging

 logger = logging.get_logger(__name__)

@@ -52,11 +51,11 @@ class TextTokenizer:

 class SPTokenizer:
     def __init__(
         self,
         vocab_file,
         num_image_tokens=20000,
         max_blank_length=80,
         byte_fallback=True,
     ):
         assert vocab_file is not None
         self.vocab_file = vocab_file
@@ -100,9 +99,7 @@ class SPTokenizer:
         text = self._encode_whitespaces(text, max_len=self.max_blank_length)
         return text

-    def encode(
-        self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
-    ) -> List[int]:
+    def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
         """
         @param text: Text to encode.
         @param linebreak: Whether to encode newline (\n) in text.
@@ -136,9 +133,7 @@ class SPTokenizer:
         text = self.postprocess(text)
         return text

-    def tokenize(
-        self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
-    ) -> List[str]:
+    def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
         """
         @param text: Text to encode.
         @param linebreak: Whether to encode newline (\n) in text.
@@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask", "position_ids"]

     def __init__(
         self,
         vocab_file,
         do_lower_case=False,
         remove_space=False,
-        bos_token='<sop>',
-        eos_token='<eop>',
-        end_token='</s>',
-        mask_token='[MASK]',
-        gmask_token='[gMASK]',
+        bos_token="<sop>",
+        eos_token="<eop>",
+        end_token="</s>",
+        mask_token="[MASK]",
+        gmask_token="[gMASK]",
         padding_side="left",
         pad_token="<pad>",
         unk_token="<unk>",
         num_image_tokens=20000,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__(
             do_lower_case=do_lower_case,
@@ -208,7 +203,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             pad_token=pad_token,
             unk_token=unk_token,
             num_image_tokens=num_image_tokens,
-            **kwargs
+            **kwargs,
         )

         self.do_lower_case = do_lower_case
@@ -243,11 +238,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

     @property
     def vocab_size(self):
-        """ Returns vocab size """
+        """Returns vocab size"""
         return self.sp_tokenizer.num_tokens

     def get_vocab(self):
-        """ Returns vocab as a dict """
+        """Returns vocab as a dict"""
         vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
@@ -264,7 +259,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return outputs

     def _tokenize(self, text, **kwargs):
-        """ Returns a tokenized string. """
+        """Returns a tokenized string."""
         text = self.preprocess_text(text)

         seq = self.sp_tokenizer.tokenize(text)
@@ -274,11 +269,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         return self.sp_tokenizer.decode_tokens(tokens)

-    def _decode(
-        self,
-        token_ids: Union[int, List[int]],
-        **kwargs
-    ) -> str:
+    def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
         if isinstance(token_ids, int):
             token_ids = [token_ids]
         if len(token_ids) == 0:
@@ -288,7 +279,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return super()._decode(token_ids, **kwargs)

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_tokenizer[token]

     def _convert_id_to_token(self, index):
@@ -309,13 +300,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             `Tuple(str)`: Paths to the files saved.
         """
         if os.path.isdir(save_directory):
-            vocab_file = os.path.join(
-                save_directory, self.vocab_files_names["vocab_file"]
-            )
+            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
         else:
             vocab_file = save_directory

-        with open(self.vocab_file, 'rb') as fin:
+        with open(self.vocab_file, "rb") as fin:
             proto_str = fin.read()

         with open(vocab_file, "wb") as writer:
@@ -324,7 +313,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return (vocab_file,)

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
         """
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
@@ -343,19 +332,19 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
         """
         gmask_id = self.sp_tokenizer[self.gmask_token]
-        eos_id = self.sp_tokenizer[self.eos_token]
+        self.sp_tokenizer[self.eos_token]
         token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
             token_ids_0 = token_ids_0 + token_ids_1
         return token_ids_0

     def _pad(
         self,
         encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
         max_length: Optional[int] = None,
         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
         pad_to_multiple_of: Optional[int] = None,
         return_attention_mask: Optional[bool] = None,
     ) -> dict:
         """
         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -421,17 +410,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
                 mask_position = required_input.index(mask_token)
                 position_ids[context_length:] = mask_position
             block_position_ids = np.concatenate(
-                [np.zeros(context_length, dtype=np.int64),
-                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                [
+                    np.zeros(context_length, dtype=np.int64),
+                    np.arange(1, seq_length - context_length + 1, dtype=np.int64),
+                ]
+            )
             encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

         if needs_to_be_padded:
             difference = max_length - len(required_input)

             if "attention_mask" in encoded_inputs:
-                encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
-                                                          pad_width=[(0, 0), (difference, 0), (difference, 0)],
-                                                          mode='constant', constant_values=True)
+                encoded_inputs["attention_mask"] = np.pad(
+                    encoded_inputs["attention_mask"],
+                    pad_width=[(0, 0), (difference, 0), (difference, 0)],
+                    mode="constant",
+                    constant_values=True,
+                )
             if "token_type_ids" in encoded_inputs:
                 encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                     "token_type_ids"
@@ -439,8 +434,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             if "special_tokens_mask" in encoded_inputs:
                 encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
             if "position_ids" in encoded_inputs:
-                encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
-                                                        pad_width=[(0, 0), (difference, 0)])
+                encoded_inputs["position_ids"] = np.pad(
+                    encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
+                )
             encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

         return encoded_inputs
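The `_pad` hunks above are pure reformatting, but the `np.pad` calls are easy to misread: `pad_width=[(0, 0), (difference, 0)]` pads only on the left of the second axis, which is exactly what left-padding `position_ids` needs. A small self-contained check with illustrative values:

import numpy as np

position_ids = np.array([[0, 1, 2], [0, 1, 1]])  # (2, seq_len), hypothetical values
difference = 2                                   # number of pad positions on the left

padded = np.pad(position_ids, pad_width=[(0, 0), (difference, 0)])
# axis 0 untouched, axis 1 left-padded with zeros:
assert padded.shape == (2, 5)
assert (padded[:, :difference] == 0).all()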
@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):

     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```
-    """
+    ```"""

     model_type = "chatglm"

     def __init__(
         self,
         vocab_size=130528,
         hidden_size=4096,
         num_layers=28,
         num_attention_heads=32,
         layernorm_epsilon=1e-5,
         use_cache=True,
         bos_token_id=130004,
         eos_token_id=130005,
         mask_token_id=130000,
         gmask_token_id=130001,
         pad_token_id=3,
         max_sequence_length=2048,
         inner_hidden_size=16384,
         position_encoding_2d=True,
         quantization_bit=0,
         pre_seq_len=None,
         prefix_projection=False,
-        **kwargs
+        **kwargs,
     ):
         self.num_layers = num_layers
         self.vocab_size = vocab_size
@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
         self.pre_seq_len = pre_seq_len
         self.prefix_projection = prefix_projection

-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            **kwargs
-        )
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo

 """ PyTorch ChatGLM model. """

-import math
 import copy
+import math
 import os
-import warnings
 import re
 import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-import torch.utils.checkpoint
 import torch.nn.functional as F
+import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    logging,
 )
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

 from .configuration_chatglm import ChatGLMConfig

 # flags required to enable jit fusion kernels

-if sys.platform != 'darwin':
+if sys.platform != "darwin":
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
     torch._C._jit_override_can_fuse_on_cpu(True)
@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
         if any(
             n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
             for n in name
         ):
             logger.info(f"Skipping {'/'.join(name)}")
             continue
@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
             array = np.transpose(array)
         try:
             assert (
                 pointer.shape == array.shape
             ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
         except AssertionError as e:
             e.args += (pointer.shape, array.shape)
@@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module):
             self.trans = torch.nn.Sequential(
                 torch.nn.Linear(config.hidden_size, config.hidden_size),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
             )
         else:
             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module):
 @torch.jit.script
 def gelu_impl(x):
     """OpenAI's gelu implementation."""
-    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
-                                       (1.0 + 0.044715 * x * x)))
+    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


 def gelu(x):
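The `gelu_impl` collapsed into one line above is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))); the constant 0.7978845608028654 is sqrt(2/pi). A quick check that the one-liner agrees with the exact erf-based GELU:

import math
import torch

def gelu_tanh(x):
    # same formula as gelu_impl above, written outside TorchScript
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

x = torch.linspace(-3, 3, steps=7)
exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
assert torch.allclose(gelu_tanh(x), exact, atol=1e-2)  # approximation is close, not exact
assert abs(0.7978845608028654 - math.sqrt(2.0 / math.pi)) < 1e-15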
@@ -181,21 +179,22 @@ def gelu(x):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
         super().__init__()
-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         inv_freq = inv_freq.half()
         self.learnable = learnable
         if learnable:
             self.inv_freq = torch.nn.Parameter(inv_freq)
             self.max_seq_len_cached = None
         else:
-            self.register_buffer('inv_freq', inv_freq)
+            self.register_buffer("inv_freq", inv_freq)
             self.max_seq_len_cached = None
             self.cos_cached = None
             self.sin_cached = None
         self.precision = precision

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                              error_msgs):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
         pass

     def forward(self, x, seq_dim=1, seq_len=None):
@@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module):
         if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
             self.max_seq_len_cached = None if self.learnable else seq_len
             t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             if self.precision == torch.bfloat16:
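The RotaryEmbedding hunks above are again formatting-only; the cache the module builds is the outer product of token positions and inverse frequencies, duplicated across the two halves of the head dimension. A sketch of that cache with a hypothetical small dim:

import torch

dim, base, seq_len = 8, 10000, 5
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
t = torch.arange(seq_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, dim/2) outer product
emb = torch.cat((freqs, freqs), dim=-1)       # (seq_len, dim), duplicated halves
cos_cached, sin_cached = emb.cos(), emb.sin()
assert cos_cached.shape == (seq_len, dim)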
@@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module):


 def rotate_half(x):
-    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


 @torch.jit.script
 def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
-    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
-        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
+    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
+        position_id, sin.squeeze(1)
+    ).unsqueeze(2)
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k


 def attention_fn(
     self,
     query_layer,
     key_layer,
     value_layer,
     attention_mask,
     hidden_size_per_partition,
     layer_id,
     layer_past=None,
     scaling_attention_score=True,
     use_cache=False,
 ):
     if layer_past is not None:
         past_key, past_value = layer_past[0], layer_past[1]
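`rotate_half` pairs each element of the first half of the last dimension with its counterpart in the second half, so `(q * cos) + (rotate_half(q) * sin)` rotates those pairs by the position angle. A small numeric check of the identity on a single pair:

import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

theta = math.pi / 6
q = torch.tensor([1.0, 0.0])  # one (x1, x2) pair
rotated = q * math.cos(theta) + rotate_half(q) * math.sin(theta)
# a plain 2-D rotation of the pair by theta:
expected = torch.tensor([math.cos(theta), math.sin(theta)])
assert torch.allclose(rotated, expected)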
@@ -285,7 +285,9 @@ def attention_fn(
     key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

     matmul_result = torch.zeros(
-        1, 1, 1,
+        1,
+        1,
+        1,
         dtype=query_layer.dtype,
         device=query_layer.device,
     )
@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):


 class SelfAttention(torch.nn.Module):
-    def __init__(self, hidden_size, num_attention_heads,
-                 layer_id, hidden_size_per_attention_head=None, bias=True,
-                 params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        layer_id,
+        hidden_size_per_attention_head=None,
+        bias=True,
+        params_dtype=torch.float,
+        position_encoding_2d=True,
+        empty_init=True,
+    ):
         if empty_init:
             init_method = skip_init
         else:
@@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module):
             attention_scores.masked_fill_(attention_mask, -10000.0)
         return attention_scores

-    def split_tensor_along_last_dim(self, tensor, num_partitions,
-                                    contiguous_split_chunks=False):
+    def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
         """Split a tensor along its last dimension.
         Arguments:
             tensor: input tensor.
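`split_tensor_along_last_dim` above is essentially `torch.split` on the last axis; the one-line signature is the only change here. For reference, the equivalent call under assumed shapes:

import torch

tensor = torch.randn(2, 3, 12)
num_partitions = 3
last_dim_size = tensor.size(-1) // num_partitions
chunks = torch.split(tensor, last_dim_size, dim=-1)  # three (2, 3, 4) views
assert all(c.shape == (2, 3, 4) for c in chunks)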
@@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module):
         return tensor_list

     def forward(
         self,
         hidden_states: torch.Tensor,
         position_ids,
         attention_mask: torch.Tensor,
         layer_id,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
         """
         hidden_states: [seq_len, batch, hidden_size]
@@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module):
             q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
             k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
             cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
-            position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
-                position_ids[:, 1, :].transpose(0, 1).contiguous()
+            position_ids, block_position_ids = (
+                position_ids[:, 0, :].transpose(0, 1).contiguous(),
+                position_ids[:, 1, :].transpose(0, 1).contiguous(),
+            )
             q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
             q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
             query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
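ChatGLM's 2D position encoding splits `position_ids` into a global channel (`[:, 0, :]`) and a block channel (`[:, 1, :]`); the hunk above only regroups the tuple assignment. A sketch of what the two channels look like for one sequence whose prompt occupies the first `context_length` tokens (values are illustrative):

import torch

seq_length, context_length = 6, 4
position_ids = torch.arange(seq_length)  # global positions 0..5
block_position_ids = torch.cat(
    (
        torch.zeros(context_length, dtype=torch.long),
        torch.arange(seq_length - context_length) + 1,
    )
)
two_d = torch.stack((position_ids, block_position_ids), dim=0)  # (2, seq_length)
# two_d -> [[0, 1, 2, 3, 4, 5],
#           [0, 0, 0, 0, 1, 2]]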
@@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module):
             hidden_size_per_partition=self.hidden_size_per_partition,
             layer_id=layer_id,
             layer_past=layer_past,
-            use_cache=use_cache
+            use_cache=use_cache,
         )

         output = self.dense(context_layer)
@@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module):


 class GLU(torch.nn.Module):
-    def __init__(self, hidden_size, inner_hidden_size=None,
-                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
+    def __init__(
+        self,
+        hidden_size,
+        inner_hidden_size=None,
+        layer_id=None,
+        bias=True,
+        activation_func=gelu,
+        params_dtype=torch.float,
+        empty_init=True,
+    ):
         super(GLU, self).__init__()
         if empty_init:
             init_method = skip_init
@@ -557,19 +576,19 @@ class GLMBlock(torch.nn.Module):
     def __init__(
         self,
         hidden_size,
         num_attention_heads,
         layernorm_epsilon,
         layer_id,
         inner_hidden_size=None,
         hidden_size_per_attention_head=None,
         layernorm=LayerNorm,
         use_bias=True,
         params_dtype=torch.float,
         num_layers=28,
         position_encoding_2d=True,
-        empty_init=True
+        empty_init=True,
     ):
         super(GLMBlock, self).__init__()
         # Set output layer initialization if not provided.
@@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             params_dtype=params_dtype,
             position_encoding_2d=self.position_encoding_2d,
-            empty_init=empty_init
+            empty_init=empty_init,
         )

         # Layernorm on the input data.
@@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module):
             bias=use_bias,
             layer_id=layer_id,
             params_dtype=params_dtype,
-            empty_init=empty_init
+            empty_init=empty_init,
         )

     def forward(
         self,
         hidden_states: torch.Tensor,
         position_ids,
         attention_mask: torch.Tensor,
         layer_id,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
         """
         hidden_states: [seq_len, batch, hidden_size]
@@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module):
             layer_id=layer_id,
             layer_past=layer_past,
             use_cache=use_cache,
-            output_attentions=output_attentions
+            output_attentions=output_attentions,
         )

         attention_output = attention_outputs[0]
@@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             for i, context_length in enumerate(context_lengths):
                 position_ids[i, context_length:] = mask_positions[i]
-            block_position_ids = [torch.cat((
-                torch.zeros(context_length, dtype=torch.long, device=device),
-                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
-            )) for context_length in context_lengths]
+            block_position_ids = [
+                torch.cat(
+                    (
+                        torch.zeros(context_length, dtype=torch.long, device=device),
+                        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
+                    )
+                )
+                for context_length in context_lengths
+            ]
             block_position_ids = torch.stack(block_position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
@@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.prefix_projection = config.prefix_projection

         self.word_embeddings = init_method(
-            torch.nn.Embedding,
-            num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
-            dtype=self.params_dtype
+            torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
         )
         self.gradient_checkpointing = False

@@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_bias=True,
                 params_dtype=self.params_dtype,
                 position_encoding_2d=self.position_encoding_2d,
-                empty_init=empty_init
+                empty_init=empty_init,
             )

-        self.layers = torch.nn.ModuleList(
-            [get_layer(layer_id) for layer_id in range(self.num_layers)]
-        )
+        self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])

         # Final layer norm before output.
         self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
@@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             self.pre_seq_len,
             self.num_layers * 2,
             self.num_attention_heads,
-            self.hidden_size // self.num_attention_heads
+            self.hidden_size // self.num_attention_heads,
         )
         # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
@@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if past_key_values is None:
             if self.pre_seq_len is not None:
-                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
-                                                  dtype=inputs_embeds.dtype)
+                past_key_values = self.get_prompt(
+                    batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
+                )
             else:
                 past_key_values = tuple([None] * len(self.layers))

             if attention_mask is None:
-                attention_mask = self.get_masks(
-                    input_ids,
-                    device=input_ids.device
-                )
+                attention_mask = self.get_masks(input_ids, device=input_ids.device)

             if position_ids is None:
                 MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                     use_gmasks.append(use_gmask)

                 position_ids = self.get_position_ids(
-                    input_ids,
-                    mask_positions=mask_positions,
-                    device=input_ids.device,
-                    use_gmasks=use_gmasks
+                    input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
                 )

         if self.pre_seq_len is not None and attention_mask is not None:
             prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
-                attention_mask.device)
+                attention_mask.device
+            )
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

@@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             attention_mask = attention_mask.to(hidden_states.device)

         for i, layer in enumerate(self.layers):
-
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer_past = past_key_values[i]
@@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                     torch.tensor(i),
                     layer_past,
                     use_cache,
-                    output_attentions
+                    output_attentions,
                 )
             else:
                 layer_ret = layer(
@@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                     layer_id=torch.tensor(i),
                     layer_past=layer_past,
                     use_cache=use_cache,
-                    output_attentions=output_attentions
+                    output_attentions=output_attentions,
                 )

             hidden_states = layer_ret[0]
@@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

         self.transformer = ChatGLMModel(config, empty_init=empty_init)

-        self.lm_head = init_method(
-            nn.Linear,
-            config.hidden_size,
-            config.vocab_size,
-            bias=False,
-            dtype=torch.half
-        )
+        self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)

         self.config = config

@@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask = model_kwargs["attention_mask"]
             if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
+                )
                 new_attention_mask = attention_mask[:, :, -1:].clone()
                 new_attention_mask[..., -1] = False
-                model_kwargs["attention_mask"] = torch.cat(
-                    [attention_mask, new_attention_mask], dim=2
-                )
+                model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)

         # update position ids
         if "position_ids" in model_kwargs:
             position_ids = model_kwargs["position_ids"]
             new_position_id = position_ids[..., -1:].clone()
             new_position_id[:, 1, :] += 1
-            model_kwargs["position_ids"] = torch.cat(
-                [position_ids, new_position_id], dim=-1
-            )
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

         return model_kwargs

     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
         past: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
             if self.position_encoding_2d:
                 position_ids = torch.tensor(
-                    [[mask_position, seq_length - context_length] for mask_position, context_length in
-                     zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+                    [
+                        [mask_position, seq_length - context_length]
+                        for mask_position, context_length in zip(mask_positions, context_lengths)
+                    ],
+                    dtype=torch.long,
+                    device=input_ids.device,
+                ).unsqueeze(-1)
             else:
-                position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
-                                            device=input_ids.device).unsqueeze(-1)
+                position_ids = torch.tensor(
+                    [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
+                ).unsqueeze(-1)

             if past is None:
                 past = past_key_values
@@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 "input_ids": last_token,
                 "past_key_values": past,
                 "position_ids": position_ids,
-                "attention_mask": attention_mask
+                "attention_mask": attention_mask,
             }
         else:
             if attention_mask is not None and attention_mask.dtype != torch.bool:
                 logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
                 attention_mask = None
             if attention_mask is None:
-                attention_mask = self.get_masks(
-                    input_ids,
-                    device=input_ids.device
-                )
+                attention_mask = self.get_masks(input_ids, device=input_ids.device)
             if position_ids is None:
                 position_ids = self.get_position_ids(
-                    input_ids,
-                    device=input_ids.device,
-                    mask_positions=mask_positions,
-                    use_gmasks=use_gmasks
+                    input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
                 )

             return {
                 "input_ids": input_ids,
                 "past_key_values": past,
                 "position_ids": position_ids,
-                "attention_mask": attention_mask
+                "attention_mask": attention_mask,
             }

     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

     @staticmethod
     def _reorder_cache(
         past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response

     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+    def chat(
+        self,
+        tokenizer,
+        query: str,
+        history: List[Tuple[str, str]] = None,
+        max_length: int = 2048,
+        num_beams=1,
+        do_sample=True,
+        top_p=0.7,
+        temperature=0.95,
+        logits_processor=None,
+        **kwargs,
+    ):
         if history is None:
             history = []
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
-        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        gen_kwargs = {
+            "max_length": max_length,
+            "num_beams": num_beams,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }
         if not history:
             prompt = query
         else:
@@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         outputs = self.generate(**inputs, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
         response = tokenizer.decode(outputs)
         response = self.process_response(response)
         history = history + [(query, response)]
         return response, history

     @torch.no_grad()
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
-                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+    def stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: List[Tuple[str, str]] = None,
+        max_length: int = 2048,
+        do_sample=True,
+        top_p=0.7,
+        temperature=0.95,
+        logits_processor=None,
+        **kwargs,
+    ):
         if history is None:
             history = []
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
-        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }
         if not history:
             prompt = query
         else:
@@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         for outputs in self.stream_generate(**inputs, **gen_kwargs):
-            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
||||||
response = self.process_response(response)
|
response = self.process_response(response)
|
||||||
new_history = history + [(query, response)]
|
new_history = history + [(query, response)]
|
||||||
@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def stream_generate(
|
def stream_generate(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
||||||
|
|
||||||
|
@ -16,9 +16,9 @@ except ImportError:
|
|||||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||||
|
|
||||||
|
|
||||||
def _prepare_logits_processor(top_k: Optional[int] = None,
|
def _prepare_logits_processor(
|
||||||
top_p: Optional[float] = None,
|
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
|
||||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
processor_list = LogitsProcessorList()
|
processor_list = LogitsProcessorList()
|
||||||
if temperature is not None and temperature != 1.0:
|
if temperature is not None and temperature != 1.0:
|
||||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||||
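
(As an aside: a minimal, self-contained sketch of how such a LogitsProcessorList is applied during sampling. The warper classes are the transformers ones imported above; the vocabulary size and parameter values are made up for illustration.)

import torch
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

# Build the same kind of processor list _prepare_logits_processor assembles.
processors = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(50), TopPLogitsWarper(0.9)])

input_ids = torch.tensor([[1, 2, 3]])  # running sequence
logits = torch.randn(1, 32000)  # next-token logits from the model (vocab size is illustrative)
warped = processors(input_ids, logits)  # each warper rescales/masks the logits in order
next_token = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)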
@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0


-def _sample(model: Actor,
-input_ids: torch.Tensor,
-max_length: int,
-early_stopping: bool = False,
-eos_token_id: Optional[int] = None,
-pad_token_id: Optional[int] = None,
-top_k: Optional[int] = None,
-top_p: Optional[float] = None,
-temperature: Optional[float] = None,
-prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
-**model_kwargs) -> torch.Tensor:
+def _sample(
+model: Actor,
+input_ids: torch.Tensor,
+max_length: int,
+early_stopping: bool = False,
+eos_token_id: Optional[int] = None,
+pad_token_id: Optional[int] = None,
+top_k: Optional[int] = None,
+top_p: Optional[float] = None,
+temperature: Optional[float] = None,
+prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+**model_kwargs,
+) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids

@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

for _ in range(input_ids.size(1), max_length):
-model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
-if prepare_inputs_fn is not None else {'input_ids': input_ids}
+model_inputs = (
+prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+)
outputs = model(**model_inputs)

-next_token_logits = outputs['logits'][:, -1, :]
+next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@ -90,20 +93,22 @@ def _sample(model: Actor,


@torch.no_grad()
-def generate(model: Actor,
-input_ids: torch.Tensor,
-max_length: int,
-num_beams: int = 1,
-do_sample: bool = True,
-early_stopping: bool = False,
-eos_token_id: Optional[int] = None,
-pad_token_id: Optional[int] = None,
-top_k: Optional[int] = None,
-top_p: Optional[float] = None,
-temperature: Optional[float] = None,
-prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
-**model_kwargs) -> torch.Tensor:
+def generate(
+model: Actor,
+input_ids: torch.Tensor,
+max_length: int,
+num_beams: int = 1,
+do_sample: bool = True,
+early_stopping: bool = False,
+eos_token_id: Optional[int] = None,
+pad_token_id: Optional[int] = None,
+top_k: Optional[int] = None,
+top_p: Optional[float] = None,
+temperature: Optional[float] = None,
+prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+**model_kwargs,
+) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.

Args:
@ -121,26 +126,28 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
-is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
-is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
-is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+is_sample_gen_mode = (num_beams == 1) and do_sample is True
+is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
-return _sample(model,
-input_ids,
-max_length,
-early_stopping=early_stopping,
-eos_token_id=eos_token_id,
-pad_token_id=pad_token_id,
-top_k=top_k,
-top_p=top_p,
-temperature=temperature,
-prepare_inputs_fn=prepare_inputs_fn,
-update_model_kwargs_fn=update_model_kwargs_fn,
-**model_kwargs)
+return _sample(
+model,
+input_ids,
+max_length,
+early_stopping=early_stopping,
+eos_token_id=eos_token_id,
+pad_token_id=pad_token_id,
+top_k=top_k,
+top_p=top_p,
+temperature=temperature,
+prepare_inputs_fn=prepare_inputs_fn,
+update_model_kwargs_fn=update_model_kwargs_fn,
+**model_kwargs,
+)
elif is_beam_gen_mode:
raise NotImplementedError
else:
@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM

-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[GPT2Config] = None,
-checkpoint: bool = False,
-lora_rank: int = 0,
-lora_train_bias: str = 'none',
-**kwargs) -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[GPT2Config] = None,
+checkpoint: bool = False,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+**kwargs,
+) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[GPT2Config] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none',
-**kwargs) -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[GPT2Config] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+**kwargs,
+) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[GPT2Config] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none') -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[GPT2Config] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM

-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
@ -1,7 +1,6 @@
from typing import Optional

-import torch
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaForCausalLM

from ..base import Actor

@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[LlamaConfig] = None,
-checkpoint: bool = False,
-lora_rank: int = 0,
-lora_train_bias: str = 'none') -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[LlamaConfig] = None,
+checkpoint: bool = False,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[LlamaConfig] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none',
-**kwargs) -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[LlamaConfig] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+**kwargs,
+) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
@ -1,7 +1,7 @@
from typing import Optional

import torch.nn as nn
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+from transformers import LlamaConfig, LlamaModel

from ..base import RewardModel

@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[LlamaConfig] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none') -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[LlamaConfig] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
@ -8,8 +8,7 @@ import torch.nn.functional as F


class LoraLinear(lora.LoRALayer, nn.Module):
-"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
-"""
+"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""

def __init__(
self,
@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
-lora_dropout: float = 0.,
+lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
-lora.LoRALayer.__init__(self,
-r=r,
-lora_alpha=lora_alpha,
-lora_dropout=lora_dropout,
-merge_weights=merge_weights)
+lora.LoRALayer.__init__(
+self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
+)
self.weight = weight
self.bias = bias

@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T

def reset_parameters(self):
-if hasattr(self, 'lora_A'):
+if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)

def train(self, mode: bool = True):

def T(w):
return w.T if self.fan_in_fan_out else w

@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.merged = False

def eval(self):

def T(w):
return w.T if self.fan_in_fan_out else w

@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-delattr(self, 'lora_A')
-delattr(self, 'lora_B')
+delattr(self, "lora_A")
+delattr(self, "lora_B")
self.merged = True
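
(Side note: the eval() branch above folds the low-rank update into the frozen weight so inference needs a single matmul. A minimal standalone sketch of that identity follows; the shapes and the scaling value are illustrative, not the project's API.)

import torch

out_features, in_features, r = 8, 16, 4
weight = torch.randn(out_features, in_features)  # frozen base weight
lora_A = torch.randn(r, in_features)  # low-rank factor A
lora_B = torch.randn(out_features, r)  # low-rank factor B
scaling = 1.0  # lora_alpha / r in standard LoRA

x = torch.randn(2, in_features)
unmerged = x @ weight.T + (x @ lora_A.T) @ lora_B.T * scaling  # LoRA forward path
merged_weight = weight + (lora_B @ lora_A) * scaling  # what eval() writes into weight.data
merged = x @ merged_weight.T
assert torch.allclose(unmerged, merged, atol=1e-5)  # identical outputs after merging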

def forward(self, x: torch.Tensor):

def T(w):
return w.T if self.fan_in_fan_out else w

@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):


def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
-assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
+assert (
+lora_rank <= linear.in_features
+), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear

@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank)


-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.

Args:
@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""

-def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps

-def forward(self,
-log_probs: torch.Tensor,
-old_log_probs: torch.Tensor,
-advantages: torch.Tensor,
-action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(
+self,
+log_probs: torch.Tensor,
+old_log_probs: torch.Tensor,
+advantages: torch.Tensor,
+action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
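
(The surr1/surr2 pair above is the PPO clipped surrogate; a small numeric sketch with made-up values, not the trainer's code.)

import torch

clip_eps = 0.2
log_probs = torch.tensor([-0.5])  # new policy log-prob of the taken action
old_log_probs = torch.tensor([-1.0])  # behaviour policy log-prob
advantages = torch.tensor([2.0])

ratio = (log_probs - old_log_probs).exp()  # exp(0.5) ~= 1.649
surr1 = ratio * advantages  # ~3.297
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages  # 1.2 * 2.0 = 2.4
loss = -torch.min(surr1, surr2).mean()  # the clipped term wins: -2.4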
@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps

-def forward(self,
-values: torch.Tensor,
-old_values: torch.Tensor,
-reward: torch.Tensor,
-action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def forward(
+self,
+values: torch.Tensor,
+old_values: torch.Tensor,
+reward: torch.Tensor,
+action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
-surr1 = (values_clipped - reward)**2
-surr2 = (values - reward)**2
+surr1 = (values_clipped - reward) ** 2
+surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
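
(Likewise, the value loss clips how far the new value estimate may move from the old one and keeps the pessimistic branch; illustrative values only.)

import torch

clip_eps = 0.2
old_values = torch.tensor([1.0])
values = torch.tensor([2.0])  # tries to move +1.0, clipped to +0.2
reward = torch.tensor([1.5])

values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)  # 1.2
surr1 = (values_clipped - reward) ** 2  # 0.09
surr2 = (values - reward) ** 2  # 0.25
loss = 0.5 * torch.max(surr1, surr2).mean()  # pessimistic branch: 0.125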
@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM

-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[OPTConfig] = None,
-checkpoint: bool = False,
-lora_rank: int = 0,
-lora_train_bias: str = 'none') -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[OPTConfig] = None,
+checkpoint: bool = False,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[OPTConfig] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none',
-**kwargs) -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[OPTConfig] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+**kwargs,
+) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""

-def __init__(self,
-pretrained: Optional[str] = None,
-config: Optional[OPTConfig] = None,
-lora_rank: int = 0,
-lora_train_bias: str = 'none') -> None:
+def __init__(
+self,
+pretrained: Optional[str] = None,
+config: Optional[OPTConfig] = None,
+lora_rank: int = 0,
+lora_train_bias: str = "none",
+) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F


-def _compute_approx_kl(log_probs: torch.Tensor,
-log_probs_base: torch.Tensor,
-action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(
+log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl


-def compute_reward(r: Union[torch.Tensor, float],
-kl_coef: float,
-log_probs: torch.Tensor,
-log_probs_base: torch.Tensor,
-action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def compute_reward(
+r: Union[torch.Tensor, float],
+kl_coef: float,
+log_probs: torch.Tensor,
+log_probs_base: torch.Tensor,
+action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
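
(compute_reward shapes the reward-model score with a KL penalty against the frozen reference model; a minimal sketch using the simple k1 estimator from the Schulman blog linked above. The project may use a lower-variance variant, and the tensors are illustrative.)

import torch

log_probs = torch.tensor([-1.2, -0.8])  # actor log-probs of the sampled tokens
log_probs_base = torch.tensor([-1.0, -1.0])  # reference-model log-probs
r = torch.tensor(0.9)  # scalar reward-model score
kl_coef = 0.1

approx_kl = (log_probs - log_probs_base).mean()  # k1 estimator: E[log p - log q]
shaped_reward = r - kl_coef * approx_kl  # discourage drifting from the reference model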
@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action log probs.
|
torch.Tensor: Action log probs.
|
||||||
"""
|
"""
|
||||||
logits = output['logits']
|
logits = output["logits"]
|
||||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||||
return log_probs[:, -num_actions:]
|
return log_probs[:, -num_actions:]
|
||||||
|
|
||||||
|
@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
|
|||||||
from .utils import low_resource_init
|
from .utils import low_resource_init
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'llama_load_quant',
|
"llama_load_quant",
|
||||||
'low_resource_init',
|
"low_resource_init",
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .loader import load_quant
|
from .loader import load_quant
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'load_quant',
|
"load_quant",
|
||||||
]
|
]
|
||||||
|
@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
|
|||||||
|
|
||||||
# ignore lm head
|
# ignore lm head
|
||||||
layers = find_layers(model)
|
layers = find_layers(model)
|
||||||
for name in ['lm_head']:
|
for name in ["lm_head"]:
|
||||||
if name in layers:
|
if name in layers:
|
||||||
del layers[name]
|
del layers[name]
|
||||||
|
|
||||||
make_quant(model, layers, wbits, groupsize)
|
make_quant(model, layers, wbits, groupsize)
|
||||||
|
|
||||||
if checkpoint.endswith('.safetensors'):
|
if checkpoint.endswith(".safetensors"):
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
|
|
||||||
model.load_state_dict(safe_load(checkpoint))
|
model.load_state_dict(safe_load(checkpoint))
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(checkpoint))
|
model.load_state_dict(torch.load(checkpoint))
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
|
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
|
||||||
if type(module) in layers:
|
if type(module) in layers:
|
||||||
return {name: module}
|
return {name: module}
|
||||||
res = {}
|
res = {}
|
||||||
for name1, child in module.named_children():
|
for name1, child in module.named_children():
|
||||||
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
|
res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
|
||||||
return res
|
return res
|
||||||
|
@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
|
|||||||
|
|
||||||
|
|
||||||
class Quantizer(nn.Module):
|
class Quantizer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, shape=1):
|
def __init__(self, shape=1):
|
||||||
super(Quantizer, self).__init__()
|
super(Quantizer, self).__init__()
|
||||||
self.register_buffer('maxq', torch.tensor(0))
|
self.register_buffer("maxq", torch.tensor(0))
|
||||||
self.register_buffer('scale', torch.zeros(shape))
|
self.register_buffer("scale", torch.zeros(shape))
|
||||||
self.register_buffer('zero', torch.zeros(shape))
|
self.register_buffer("zero", torch.zeros(shape))
|
||||||
|
|
||||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
|
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
|
||||||
self.maxq = torch.tensor(2**bits - 1)
|
self.maxq = torch.tensor(2**bits - 1)
|
||||||
self.perchannel = perchannel
|
self.perchannel = perchannel
|
||||||
self.sym = sym
|
self.sym = sym
|
||||||
@ -68,7 +67,7 @@ class Quantizer(nn.Module):
|
|||||||
self.zero = torch.round(-xmin / self.scale)
|
self.zero = torch.round(-xmin / self.scale)
|
||||||
|
|
||||||
if self.mse:
|
if self.mse:
|
||||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
best = torch.full([x.shape[0]], float("inf"), device=dev)
|
||||||
for i in range(int(self.maxshrink * self.grid)):
|
for i in range(int(self.maxshrink * self.grid)):
|
||||||
p = 1 - i / self.grid
|
p = 1 - i / self.grid
|
||||||
xmin1 = p * xmin
|
xmin1 = p * xmin
|
||||||
@ -123,13 +122,12 @@ class Quantizer(nn.Module):
|
|||||||
try:
|
try:
|
||||||
import quant_cuda
|
import quant_cuda
|
||||||
except:
|
except:
|
||||||
print('CUDA extension not installed.')
|
print("CUDA extension not installed.")
|
||||||
|
|
||||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||||
|
|
||||||
|
|
||||||
class QuantLinear(nn.Module):
|
class QuantLinear(nn.Module):
|
||||||
|
|
||||||
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if bits not in [2, 3, 4, 8]:
|
if bits not in [2, 3, 4, 8]:
|
||||||
@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
|
|||||||
groupsize = groupsize if groupsize != -1 else infeatures
|
groupsize = groupsize if groupsize != -1 else infeatures
|
||||||
self.groupsize = groupsize
|
self.groupsize = groupsize
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
|
"qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
|
||||||
dtype=torch.int))
|
)
|
||||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||||
self.register_buffer('bias', torch.zeros(outfeatures))
|
self.register_buffer("bias", torch.zeros(outfeatures))
|
||||||
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||||
self._initialized_quant_state = False
|
self._initialized_quant_state = False
|
||||||
|
|
||||||
def pack(self, linear, scales, zeros):
|
def pack(self, linear, scales, zeros):
|
||||||
@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
|
|||||||
for idx in range(self.infeatures):
|
for idx in range(self.infeatures):
|
||||||
g_idx = idx // self.groupsize
|
g_idx = idx // self.groupsize
|
||||||
intweight.append(
|
intweight.append(
|
||||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
|
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
|
||||||
None])
|
:, None
|
||||||
|
]
|
||||||
|
)
|
||||||
intweight = torch.cat(intweight, dim=1)
|
intweight = torch.cat(intweight, dim=1)
|
||||||
intweight = intweight.t().contiguous()
|
intweight = intweight.t().contiguous()
|
||||||
intweight = intweight.numpy().astype(np.uint32)
|
intweight = intweight.numpy().astype(np.uint32)
|
||||||
@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
|
|||||||
return y.reshape(outshape)
|
return y.reshape(outshape)
|
||||||
|
|
||||||
|
|
||||||
def make_quant(module, names, bits, groupsize, name=''):
|
def make_quant(module, names, bits, groupsize, name=""):
|
||||||
if isinstance(module, QuantLinear):
|
if isinstance(module, QuantLinear):
|
||||||
return
|
return
|
||||||
for attr in dir(module):
|
for attr in dir(module):
|
||||||
tmp = getattr(module, attr)
|
tmp = getattr(module, attr)
|
||||||
name1 = name + '.' + attr if name != '' else attr
|
name1 = name + "." + attr if name != "" else attr
|
||||||
if name1 in names:
|
if name1 in names:
|
||||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
||||||
for name1, child in module.named_children():
|
for name1, child in module.named_children():
|
||||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
|
||||||
|
@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def low_resource_init():
|
def low_resource_init():
|
||||||
"""This context manager disables weight initialization and sets the default float dtype to half.
|
"""This context manager disables weight initialization and sets the default float dtype to half."""
|
||||||
"""
|
|
||||||
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
|
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
|
||||||
old_uniform_ = torch.nn.init.uniform_
|
old_uniform_ = torch.nn.init.uniform_
|
||||||
old_normal_ = torch.nn.init.normal_
|
old_normal_ = torch.nn.init.normal_
|
||||||
|
@ -5,7 +5,7 @@ from coati.experience_maker import Experience
|
|||||||
|
|
||||||
class TrainerCallback(ABC):
|
class TrainerCallback(ABC):
|
||||||
"""
|
"""
|
||||||
Base callback class. It defines the interface for callbacks.
|
Base callback class. It defines the interface for callbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_fit_start(self) -> None:
|
def on_fit_start(self) -> None:
|
||||||
@ -40,7 +40,6 @@ class TrainerCallback(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class MakerCallback(ABC):
|
class MakerCallback(ABC):
|
||||||
|
|
||||||
def on_loop_start(self) -> None:
|
def on_loop_start(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
|||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.start_time: Optional[float] = None
|
self.start_time: Optional[float] = None
|
||||||
self.duration: float = 0.
|
self.duration: float = 0.0
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
self.start_time = time()
|
self.start_time = time()
|
||||||
@ -42,13 +41,13 @@ class Timer:
|
|||||||
self.duration += time() - self.start_time
|
self.duration += time() - self.start_time
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.duration = 0.
|
self.duration = 0.0
|
||||||
|
|
||||||
|
|
||||||
class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||||
|
def __init__(
|
||||||
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
|
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
|
||||||
reward_model_num_params: int) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.world_size = get_world_size()
|
self.world_size = get_world_size()
|
||||||
self.actor_num_params = actor_num_params
|
self.actor_num_params = actor_num_params
|
||||||
@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
|||||||
self.make_experience_flop: int = 0
|
self.make_experience_flop: int = 0
|
||||||
|
|
||||||
print_rank_0(
|
print_rank_0(
|
||||||
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_make_experience_start(self) -> None:
|
def on_make_experience_start(self) -> None:
|
||||||
@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
|||||||
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
|
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
|
||||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
||||||
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
|
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||||
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
|
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
|
||||||
(self.total_samples * self.world_size)
|
self.total_samples * self.world_size
|
||||||
|
)
|
||||||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||||
|
|
||||||
print_rank_0(
|
print_rank_0(
|
||||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
"Making Experience Performance Summary:\n"
|
||||||
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
|
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||||
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
|
||||||
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||||
|
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||||
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainerPerformanceEvaluator(TrainerCallback):
|
class TrainerPerformanceEvaluator(TrainerCallback):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self,
|
||||||
actor_num_params: int,
|
actor_num_params: int,
|
||||||
critic_num_params: int,
|
critic_num_params: int,
|
||||||
enable_grad_checkpoint: bool = False,
|
enable_grad_checkpoint: bool = False,
|
||||||
ignore_first_episodes: int = 1) -> None:
|
ignore_first_episodes: int = 1,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.world_size = get_world_size()
|
self.world_size = get_world_size()
|
||||||
self.actor_num_params = actor_num_params
|
self.actor_num_params = actor_num_params
|
||||||
@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
|||||||
self.learn_flop: int = 0
|
self.learn_flop: int = 0
|
||||||
|
|
||||||
print_rank_0(
|
print_rank_0(
|
||||||
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_episode_start(self, episodes: int) -> None:
|
def on_episode_start(self, episodes: int) -> None:
|
||||||
@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
|||||||
|
|
||||||
def on_fit_end(self) -> None:
|
def on_fit_end(self) -> None:
|
||||||
if self.total_samples == 0:
|
if self.total_samples == 0:
|
||||||
print_rank_0('No samples are collected, skip trainer performance evaluation')
|
print_rank_0("No samples are collected, skip trainer performance evaluation")
|
||||||
return
|
return
|
||||||
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||||
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
|
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
|
||||||
@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
|||||||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||||
|
|
||||||
print_rank_0(
|
print_rank_0(
|
||||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
"Learning Performance Summary:\n"
|
||||||
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||||
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
|
||||||
|
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||||
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||||
|
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||||
)
|
)
|
||||||
|
@ -1,20 +1,15 @@
|
|||||||
import asyncio
|
from typing import List
|
||||||
import copy
|
|
||||||
import random
|
|
||||||
from threading import Lock
|
|
||||||
from typing import Any, List
|
|
||||||
|
|
||||||
import ray
|
|
||||||
import torch
|
import torch
|
||||||
from coati.experience_buffer import ExperienceBuffer
|
|
||||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||||
from coati.experience_maker.base import Experience
|
from coati.experience_maker.base import Experience
|
||||||
|
|
||||||
# from torch.multiprocessing import Queue
|
# from torch.multiprocessing import Queue
|
||||||
from ray.util.queue import Queue
|
from ray.util.queue import Queue
|
||||||
|
|
||||||
|
|
||||||
class DetachedReplayBuffer:
|
class DetachedReplayBuffer:
|
||||||
'''
|
"""
|
||||||
Detached replay buffer. Share Experience across workers on the same node.
|
Detached replay buffer. Share Experience across workers on the same node.
|
||||||
Therefore, a trainer node is expected to have only one instance.
|
Therefore, a trainer node is expected to have only one instance.
|
||||||
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
|
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
|
||||||
@ -24,7 +19,7 @@ class DetachedReplayBuffer:
|
|||||||
tp_world_size: Number of workers in the same tp group
|
tp_world_size: Number of workers in the same tp group
|
||||||
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
|
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
|
||||||
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
|
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||||
self.sample_batch_size = sample_batch_size
|
self.sample_batch_size = sample_batch_size
|
||||||
@ -34,23 +29,23 @@ class DetachedReplayBuffer:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def append(self, experience: Experience) -> None:
|
def append(self, experience: Experience) -> None:
|
||||||
'''
|
"""
|
||||||
Expected to be called remotely.
|
Expected to be called remotely.
|
||||||
'''
|
"""
|
||||||
items = split_experience_batch(experience)
|
items = split_experience_batch(experience)
|
||||||
self.extend(items)
|
self.extend(items)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def extend(self, items: List[BufferItem]) -> None:
|
def extend(self, items: List[BufferItem]) -> None:
|
||||||
'''
|
"""
|
||||||
Expected to be called remotely.
|
Expected to be called remotely.
|
||||||
'''
|
"""
|
||||||
self.batch_collector.extend(items)
|
self.batch_collector.extend(items)
|
||||||
while len(self.batch_collector) >= self.sample_batch_size:
|
while len(self.batch_collector) >= self.sample_batch_size:
|
||||||
items = self.batch_collector[:self.sample_batch_size]
|
items = self.batch_collector[: self.sample_batch_size]
|
||||||
experience = make_experience_batch(items)
|
experience = make_experience_batch(items)
|
||||||
self.items.put(experience, block=True)
|
self.items.put(experience, block=True)
|
||||||
self.batch_collector = self.batch_collector[self.sample_batch_size:]
|
self.batch_collector = self.batch_collector[self.sample_batch_size :]
|
||||||
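
(extend() above chunks incoming items into fixed-size batches and enqueues them; the same batching pattern is sketched here with plain ints instead of the project's BufferItem/Experience types.)

from queue import Queue
from typing import List

sample_batch_size = 4
batch_collector: List[int] = []
batch_queue: Queue = Queue()

def extend(items: List[int]) -> None:
    global batch_collector
    batch_collector.extend(items)
    # Flush every full batch to the queue; keep the remainder for later.
    while len(batch_collector) >= sample_batch_size:
        batch_queue.put(batch_collector[:sample_batch_size], block=True)
        batch_collector = batch_collector[sample_batch_size:]

extend([1, 2, 3])  # not enough for a batch yet
extend([4, 5])  # enqueues [1, 2, 3, 4]; [5] stays in the collector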
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
# self.items.close()
|
# self.items.close()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -15,7 +15,7 @@ from .utils import is_rank_0
|
|||||||
|
|
||||||
|
|
||||||
class DetachedTrainer(ABC):
|
class DetachedTrainer(ABC):
|
||||||
'''
|
"""
|
||||||
Base class for detached rlhf trainers.
|
Base class for detached rlhf trainers.
|
||||||
'detach' means that the experience maker is detached compared to a normal Trainer.
|
'detach' means that the experience maker is detached compared to a normal Trainer.
|
||||||
Please set name attribute during init:
|
Please set name attribute during init:
|
||||||
@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
|
|||||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
experience_maker_holder_name_list: List[str],
|
self,
|
||||||
train_batch_size: int = 8,
|
experience_maker_holder_name_list: List[str],
|
||||||
buffer_limit: int = 0,
|
train_batch_size: int = 8,
|
||||||
dataloader_pin_memory: bool = True,
|
buffer_limit: int = 0,
|
||||||
callbacks: List[TrainerCallback] = [],
|
dataloader_pin_memory: bool = True,
|
||||||
debug: bool = False) -> None:
|
callbacks: List[TrainerCallback] = [],
|
||||||
|
debug: bool = False,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
|
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
|
||||||
self.dataloader_pin_memory = dataloader_pin_memory
|
self.dataloader_pin_memory = dataloader_pin_memory
|
||||||
@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
|
|||||||
def _learn(self, update_steps: int, train_epochs: int) -> None:
|
def _learn(self, update_steps: int, train_epochs: int) -> None:
|
||||||
data = []
|
data = []
|
||||||
# warmup
|
# warmup
|
||||||
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
|
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
|
||||||
self._on_epoch_start(0)
|
self._on_epoch_start(0)
|
||||||
self._learn_epoch(pbar, data)
|
self._learn_epoch(pbar, data)
|
||||||
self._on_epoch_end(0)
|
self._on_epoch_end(0)
|
||||||
# item is already a batch
|
# item is already a batch
|
||||||
dataloader = DataLoader(data,
|
dataloader = DataLoader(
|
||||||
batch_size=1,
|
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
|
||||||
shuffle=True,
|
)
|
||||||
pin_memory=self.dataloader_pin_memory,
|
|
||||||
collate_fn=lambda x: x[0])
|
|
||||||
for epoch in range(1, train_epochs):
|
for epoch in range(1, train_epochs):
|
||||||
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
|
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
|
||||||
self._on_epoch_start(epoch)
|
self._on_epoch_start(epoch)
|
||||||
self._learn_epoch(pbar, data)
|
self._learn_epoch(pbar, data)
|
||||||
self._on_epoch_end(epoch)
|
self._on_epoch_end(epoch)
|
||||||
@ -104,7 +104,7 @@ class DetachedTrainer(ABC):
|
|||||||
|
|
||||||
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
|
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
|
||||||
self._on_fit_start()
|
self._on_fit_start()
|
||||||
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
|
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
|
||||||
self._on_episode_start(i)
|
self._on_episode_start(i)
|
||||||
self._learn(update_steps, train_epochs)
|
self._learn(update_steps, train_epochs)
|
||||||
self._on_update_start()
|
self._on_update_start()
|
||||||
|
@ -1,12 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple
 
 import ray
 import torch
-from coati.experience_maker import Experience, NaiveExperienceMaker
+from coati.experience_maker import Experience
 from coati.models.base import Actor, Critic
 from coati.models.loss import PolicyLoss, ValueLoss
-from coati.trainer.callbacks import Callback
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
 from torch.optim import Adam
 
 from colossalai.nn.optimizer import HybridAdam

@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
 from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
 from .detached_trainer_base import DetachedTrainer
 from .lora_constructor import LoRAConstructor
-from .utils import (
-    get_actor_from_args,
-    get_critic_from_args,
-    get_model_numel,
-    get_rank,
-    get_strategy_from_args,
-    is_rank_0,
-    set_dist_env,
-    state_dict_to,
-)
+from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
 
 
-@ray.remote(concurrency_groups={
-    "buffer_length": 1,
-    "buffer_append": 1,
-    "buffer_sample": 1,
-    "model_io": 1,
-    "compute": 1
-})
+@ray.remote(
+    concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
+)
 class DetachedPPOTrainer(DetachedTrainer):
-    '''
+    """
     Detached Trainer for PPO algorithm
     Args:
         strategy (Strategy): the strategy to use for training
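Note: the `concurrency_groups` mapping above gives each class of actor method its own single-slot execution group, so buffer traffic, model I/O, and compute cannot starve one another inside the same Ray actor. A minimal sketch of that Ray feature in isolation (the actor and method names are illustrative, not from this codebase):

import ray

@ray.remote(concurrency_groups={"io": 1, "compute": 1})
class Worker:
    @ray.method(concurrency_group="io")
    async def load(self, x):
        return x  # scheduled on the "io" group

    @ray.method(concurrency_group="compute")
    async def step(self, x):
        return x * 2  # runs concurrently with "io" methods

ray.init()
w = Worker.remote()
print(ray.get([w.load.remote(1), w.step.remote(2)]))  # -> [1, 4]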
@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
-    '''
+    """
 
     def __init__(
         self,

@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
             self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
             self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
 
-        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
-            self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
+        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
+            (self.actor, self.actor_optim), (self.critic, self.critic_optim)
+        )
 
         # configure trainer
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
-        super().__init__(experience_maker_holder_name_list,
-                         train_batch_size=train_batch_size,
-                         buffer_limit=buffer_limit,
-                         dataloader_pin_memory=dataloader_pin_memory,
-                         callbacks=callbacks,
-                         debug=debug)
+        super().__init__(
+            experience_maker_holder_name_list,
+            train_batch_size=train_batch_size,
+            buffer_limit=buffer_limit,
+            dataloader_pin_memory=dataloader_pin_memory,
+            callbacks=callbacks,
+            debug=debug,
+        )
         if self._debug:
-            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
+            print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
 
         self._update_lora_weights = update_lora_weights
 

@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
     def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
         if not fully_update:
-            config['requires_grad_only'] = True
+            config["requires_grad_only"] = True
         self.update_target_holder_list()
         # mark start, ensure order
         tasks = []

@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
                     target_holder.update_experience_maker.remote(
                         new_actor_state_dict=state_dict_shard,
                         new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
-                        fully_update=fully_update))
+                        fully_update=fully_update,
+                    )
+                )
         # sending loop
         for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
             for target_holder in self.target_holder_list:

@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
                     target_holder.update_experience_maker.remote(
                         new_critic_state_dict=state_dict_shard,
                         new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
-                        fully_update=fully_update))
+                        fully_update=fully_update,
+                    )
+                )
         ray.get(tasks)
         # mark end
         for target_holder in self.target_holder_list:
@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):
 
         num_actions = experience.action_mask.size(1)
         action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
-        actor_loss = self.actor_loss_fn(action_log_probs,
-                                        experience.action_log_probs,
-                                        experience.advantages,
-                                        action_mask=experience.action_mask)
+        actor_loss = self.actor_loss_fn(
+            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+        )
         self.strategy.backward(actor_loss, self.actor, self.actor_optim)
         self.strategy.optimizer_step(self.actor_optim)
         self.actor_optim.zero_grad()
 
-        values = self.critic(experience.sequences,
-                             action_mask=experience.action_mask,
-                             attention_mask=experience.attention_mask)
-        critic_loss = self.critic_loss_fn(values,
-                                          experience.values,
-                                          experience.reward,
-                                          action_mask=experience.action_mask)
+        values = self.critic(
+            experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+        )
+        critic_loss = self.critic_loss_fn(
+            values, experience.values, experience.reward, action_mask=experience.action_mask
+        )
 
         self.strategy.backward(critic_loss, self.critic, self.critic_optim)
         self.strategy.optimizer_step(self.critic_optim)
         self.critic_optim.zero_grad()
-        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+        return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
 
     def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
         self.strategy.save_model(self.actor, path, only_rank0)
@ -1,53 +1,49 @@
 import os
 import time
 import tracemalloc
-from copy import deepcopy
 from threading import Lock
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
 
 import ray
 import torch
-import torch.nn as nn
-from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
-from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
+from coati.experience_buffer.utils import split_experience_batch
+from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic, RewardModel
-from coati.trainer.callbacks import Callback
 from coati.trainer.strategies import Strategy
-from coati.trainer.strategies.sampler import DistributedSampler
-from ray.exceptions import GetTimeoutError
 from torch import Tensor
 from tqdm import tqdm
 
 from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
 from .lora_constructor import LoRAConstructor
-from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
+from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
 
 
 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
 class ExperienceMakerHolder:
-    '''
+    """
     Args:
         detached_trainer_name_list: str list to get ray actor handles
         strategy:
         kl_coef: the coefficient of kl divergence loss
         sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
-    '''
+    """
 
     def __init__(
        self,
        detached_trainer_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        # a function returns (actor, critic, reward_model, initial_model)
        model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
        env_info: Dict[str, str] = None,
        sync_models_from_trainers: bool = False,
        buffer_cpu_offload: bool = True,
        kl_coef: float = 0.1,
        callbacks: List[MakerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
-       **generate_kwargs):
+       **generate_kwargs,
+    ):
         # set environment variables
         if env_info:
             set_dist_env(env_info=env_info)
@ -66,8 +62,9 @@ class ExperienceMakerHolder:
             critic_numel = get_model_numel(critic)
             initial_model_numel = get_model_numel(initial_model)
             reward_model_numel = get_model_numel(reward_model)
-            evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
-                                                            reward_model_numel)
+            evaluator = ExperienceMakerPerformanceEvaluator(
+                actor_numel, critic_numel, initial_model_numel, reward_model_numel
+            )
             callbacks = callbacks + [evaluator]
 
         actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)

@ -89,9 +86,9 @@ class ExperienceMakerHolder:
         self._target_idx = 0
 
         if self._debug:
-            print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
+            print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
             if not self._is_fully_initialized:
-                print(f'[maker{get_rank()}] Waiting for INIT')
+                print(f"[maker{get_rank()}] Waiting for INIT")
 
     def _get_ready(self):
         while not self._fully_initialized():
@ -136,7 +133,7 @@ class ExperienceMakerHolder:
         self._on_make_experience_end(experience)
         self._on_send_start()
         if self.buffer_cpu_offload:
-            experience.to_device('cpu')
+            experience.to_device("cpu")
         self._send_items(experience)
         self._on_send_end()
         self._on_batch_end()

@ -155,7 +152,7 @@ class ExperienceMakerHolder:
         if num_steps > 0:
             # ignore num epochs
             it = iter(dataloader)
-            for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
+            for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
                 try:
                     batch = next(it)
                 except StopIteration:

@ -163,7 +160,7 @@ class ExperienceMakerHolder:
                     batch = next(it)
                 self._inference_step(batch)
         else:
-            with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
+            with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
                 for _ in range(num_epochs):
                     for batch in dataloader:
                         self._inference_step(batch)
@ -171,22 +168,24 @@ class ExperienceMakerHolder:
         self._on_loop_end()
 
     @ray.method(concurrency_group="model_io")
-    def update_experience_maker(self,
-                                new_actor_state_dict: Dict[str, Any] = None,
-                                new_actor_lora_config_dict: Dict[str, Any] = None,
-                                new_critic_state_dict: Dict[str, Any] = None,
-                                new_critic_lora_config_dict: Dict[str, Any] = None,
-                                fully_update: bool = False,
-                                chunk_start: bool = None,
-                                chunk_end: bool = None):
-        '''
-        called by trainer
-        chunk_start: Set True at the first call. Before sending state_dict calls
-        chunk_end: Set True at the last call. After sending state_dict calls.
-        fully_update: Set True if you want to sync models when initializing
+    def update_experience_maker(
+        self,
+        new_actor_state_dict: Dict[str, Any] = None,
+        new_actor_lora_config_dict: Dict[str, Any] = None,
+        new_critic_state_dict: Dict[str, Any] = None,
+        new_critic_lora_config_dict: Dict[str, Any] = None,
+        fully_update: bool = False,
+        chunk_start: bool = None,
+        chunk_end: bool = None,
+    ):
+        """
+        called by trainer
+        chunk_start: Set True at the first call. Before sending state_dict calls
+        chunk_end: Set True at the last call. After sending state_dict calls.
+        fully_update: Set True if you want to sync models when initializing
 
         TODO: load_state_dict integrate with model-sharding strategy
-        '''
+        """
         _watch_memory = self._debug
         if chunk_start:
             if self._debug:
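Note: chunk_start/chunk_end implement a simple framing protocol: the trainer sends one call to open the transfer, a stream of state-dict shards, and one call to close it, so the maker can hold its update lock for the whole exchange. A hedged sketch of that handshake from the sender's side (the wrapper function and shard source are hypothetical, mirroring the trainer code in this diff):

import ray

def push_model_update(target_holders, state_dict_shards, fully_update=False):
    # 1. open the chunked transfer on every receiver
    for holder in target_holders:
        holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update)
    # 2. stream the shards; each call carries one piece of the state dict
    tasks = []
    for shard in state_dict_shards:
        for holder in target_holders:
            tasks.append(
                holder.update_experience_maker.remote(new_actor_state_dict=shard, fully_update=fully_update)
            )
    ray.get(tasks)  # wait until every shard has been applied
    # 3. close the transfer so receivers can release their update lock
    for holder in target_holders:
        holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)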
@ -202,18 +201,22 @@ class ExperienceMakerHolder:
             else:
                 new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
                 state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
-                    new_actor_state_dict, new_actor_lora_config_dict)
+                    new_actor_state_dict, new_actor_lora_config_dict
+                )
                 self.actor_lora_constructor.load_state_dict_increase(
-                    self.experience_maker.actor.model, state_dict_increase)
+                    self.experience_maker.actor.model, state_dict_increase
+                )
         if new_critic_state_dict is not None:
             if not self._update_lora_weights or fully_update:
                 self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
             else:
                 new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
                 state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
-                    new_critic_state_dict, new_critic_lora_config_dict)
+                    new_critic_state_dict, new_critic_lora_config_dict
+                )
                 self.critic_lora_constructor.load_state_dict_increase(
-                    self.experience_maker.critic, state_dict_increase)
+                    self.experience_maker.critic, state_dict_increase
+                )
 
         # the lock must be released after both actor and critic being updated
         if chunk_end:
@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
     origin_model = actor.model
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+    if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
+        new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
 
-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
+    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
+        new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
 
     return new_kwargs
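Note: this helper fills in generation hooks only when the caller did not supply them and the wrapped HuggingFace model actually provides them. The same guard-with-hasattr pattern in isolation (the dict keys are the ones from the diff; the stub model is a stand-in):

def fill_default_hooks(kwargs: dict, model) -> dict:
    """Return a copy of kwargs with hooks borrowed from the model when absent."""
    new_kwargs = {**kwargs}  # never mutate the caller's dict
    if "prepare_inputs_fn" not in new_kwargs and hasattr(model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = model.prepare_inputs_for_generation
    return new_kwargs

class _Stub:
    def prepare_inputs_for_generation(self, *args, **kw):
        return {}

assert "prepare_inputs_fn" in fill_default_hooks({}, _Stub())
# an explicitly supplied hook is left untouched
assert fill_default_hooks({"prepare_inputs_fn": print}, _Stub())["prepare_inputs_fn"] is print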
@ -1,11 +1,9 @@
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
 
-import torch
 import torch.nn as nn
 from coati.models.lora import LoraLinear
-from loralib.layers import LoRALayer
 
 
 @dataclass
@ -17,7 +15,7 @@ class LoRAConfig:
 
 
 class LoRAConstructor:
-    '''
+    """
     Tools for reconstructing a model from a remote LoRA model.
     (Transferring only LoRA data costs much less!)
     Usage:

@ -36,7 +34,7 @@ class LoRAConstructor:
     Step 5 (Receiver):
         load_state_dict_increase()
 
-    '''
+    """
 
     def __init__(self):
         self.lora_config_dict = None

@ -45,10 +43,10 @@ class LoRAConstructor:
         self.lora_config_dict = lora_config_dict
 
     def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
-        '''
+        """
         xxx.lora_A, xxx.lora_B -->> xxx.weight
         Warning: the xxx.weight here is the increment actually.
-        '''
+        """
         if lora_config_dict is not None:
             self.register_lora_config(lora_config_dict)
 
@ -56,24 +54,25 @@ class LoRAConstructor:
         config_iter = iter(self.lora_config_dict.items())
         lora_A, lora_B, layer_prefix = None, None, None
         for k, v in state_dict_lora.items():
-            if k.rpartition('.')[-1] == 'lora_A':
+            if k.rpartition(".")[-1] == "lora_A":
                 lora_A = v
-                layer_prefix = k.rpartition('.')[0]
-            elif k.rpartition('.')[-1] == 'lora_B':
-                assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
+                layer_prefix = k.rpartition(".")[0]
+            elif k.rpartition(".")[-1] == "lora_B":
+                assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
                 layer_prefix_2, config = next(config_iter)
                 assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                 lora_B = v
                 weight_data_increase = self._compute(lora_A, lora_B, config)
-                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
+                state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
                 lora_A, lora_B, layer_prefix = None, None, None
             else:
-                raise ValueError('unexpected key')
+                raise ValueError("unexpected key")
         return state_dict_increase
 
     def _compute(self, lora_A, lora_B, config=LoRAConfig()):
         def T(w):
             return w.T if config.fan_in_fan_out else w
 
         if config.r > 0:
             scaling = config.lora_alpha / config.r
             weight_data_increase = T(lora_B @ lora_A) * scaling
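Note: `_compute` turns a (lora_A, lora_B) pair back into a dense weight increment, delta_W = (B @ A) * (alpha / r), optionally transposed for fan-in/fan-out layouts. A tiny numeric sketch of that formula with made-up shapes (rank 2 on a 4x3 layer; not coati's code, just the arithmetic):

import torch

r, alpha = 2, 4              # LoRA rank and scaling numerator
lora_A = torch.randn(r, 3)   # (r, in_features)
lora_B = torch.randn(4, r)   # (out_features, r)

scaling = alpha / r
delta_w = (lora_B @ lora_A) * scaling  # dense increment, shape (4, 3)

# Adding the increment onto the base weight recovers the adapted weight:
base_w = torch.zeros(4, 3)
adapted_w = base_w + delta_w
assert adapted_w.shape == (4, 3)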
@ -81,21 +80,21 @@ class LoRAConstructor:
             return 0
 
     def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
-        '''
+        """
         The final reconstruction step
-        '''
+        """
         # naive approach
         model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
 
     @staticmethod
     def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
-        '''
+        """
         if keep_non_lora, also return non_lora state_dict
-        '''
+        """
         state_dict_lora = OrderedDict()
         state_dict_non_lora = OrderedDict()
         for k, v in state_dict.items():
-            if 'lora_A' in k or 'lora_B' in k:
+            if "lora_A" in k or "lora_B" in k:
                 state_dict_lora[k] = v
             elif keep_non_lora:
                 state_dict_non_lora[k] = v
@ -106,17 +105,19 @@ class LoRAConstructor:
 
     @staticmethod
     def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
-        '''
+        """
         extract LoraLinear model.
         return OrderedDict(): name -> LoRAConfig
-        '''
+        """
         lora_config_dict = OrderedDict()
 
         for name, child in model.named_modules():
             if isinstance(child, LoraLinear):
-                lora_config_dict[name] = LoRAConfig(r=child.r,
-                                                    lora_alpha=child.lora_alpha,
-                                                    lora_dropout=child.lora_dropout,
-                                                    fan_in_fan_out=child.fan_in_fan_out)
+                lora_config_dict[name] = LoRAConfig(
+                    r=child.r,
+                    lora_alpha=child.lora_alpha,
+                    lora_dropout=child.lora_dropout,
+                    fan_in_fan_out=child.fan_in_fan_out,
+                )
 
         return lora_config_dict
@ -1,6 +1,6 @@
 import os
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict
 
 import torch
 import torch.distributed as dist

@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
 from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
 from coati.models.opt import OPTRM, OPTActor, OPTCritic
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
 
 
 def is_rank_0() -> bool:
@ -26,13 +26,13 @@ def get_world_size() -> int:
 
 
 def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
-    if model == 'gpt2':
+    if model == "gpt2":
         actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
-    elif model == 'bloom':
+    elif model == "bloom":
         actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
-    elif model == 'opt':
+    elif model == "opt":
         actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
-    elif model == 'llama':
+    elif model == "llama":
         actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
     else:
         raise ValueError(f'Unsupported actor model "{model}"')

@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
 
 
 def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
-    if model == 'gpt2':
+    if model == "gpt2":
         critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
-    elif model == 'bloom':
+    elif model == "bloom":
         critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
-    elif model == 'opt':
+    elif model == "opt":
         critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
-    elif model == 'llama':
+    elif model == "llama":
         critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
     else:
         raise ValueError(f'Unsupported reward model "{model}"')

@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
 
 
 def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
-    if model == 'gpt2':
+    if model == "gpt2":
         reward_model = GPTRM(pretrained=pretrained, config=config)
-    elif model == 'bloom':
+    elif model == "bloom":
         reward_model = BLOOMRM(pretrained=pretrained, config=config)
-    elif model == 'opt':
+    elif model == "opt":
         reward_model = OPTRM(pretrained=pretrained, config=config)
-    elif model == 'llama':
+    elif model == "llama":
         reward_model = LlamaRM(pretrained=pretrained, config=config)
     else:
         raise ValueError(f'Unsupported reward model "{model}"')
@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
 
 
 def get_strategy_from_args(strategy: str):
-    if strategy == 'ddp':
+    if strategy == "ddp":
         strategy_ = DDPStrategy()
-    elif strategy == 'colossalai_gemini':
-        strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
-    elif strategy == 'colossalai_zero2':
-        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
-    elif strategy == 'colossalai_gemini_cpu':
-        strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
-    elif strategy == 'colossalai_zero2_cpu':
-        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+    elif strategy == "colossalai_gemini":
+        strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+    elif strategy == "colossalai_zero2":
+        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+    elif strategy == "colossalai_gemini_cpu":
+        strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+    elif strategy == "colossalai_zero2_cpu":
+        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
         raise ValueError(f'Unsupported strategy "{strategy}"')
     return strategy_
 
 
 def get_tokenizer_from_args(model: str, **kwargs):
-    if model == 'gpt2':
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-    elif model == 'bloom':
-        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
-    elif model == 'opt':
+    if model == "gpt2":
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    elif model == "bloom":
+        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+    elif model == "opt":
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-    elif model == 'llama':
+    elif model == "llama":
         pretrain_path = kwargs["pretrain"]
         tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
     else:
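Note: together these factories wire up a whole actor/critic/reward-model/tokenizer/strategy stack from plain CLI strings. A hedged sketch of how a launcher script might combine them (the argument values are examples, not defaults from the codebase):

# Hypothetical launcher snippet built on the helpers above.
strategy = get_strategy_from_args("colossalai_zero2")
actor = get_actor_from_args("gpt2", pretrained=None, lora_rank=0)
critic = get_critic_from_args("gpt2", pretrained=None, lora_rank=0)
reward_model = get_reward_model_from_args("gpt2", pretrained=None)
tokenizer = get_tokenizer_from_args("gpt2")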
@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
 
 
 def set_dist_env(env_info: Dict[str, str]):
-    os.environ["RANK"] = env_info['rank']
-    os.environ["LOCAL_RANK"] = env_info['local_rank']
-    os.environ["WORLD_SIZE"] = env_info['world_size']
-    os.environ['MASTER_PORT'] = env_info['master_port']
-    os.environ['MASTER_ADDR'] = env_info['master_addr']
+    os.environ["RANK"] = env_info["rank"]
+    os.environ["LOCAL_RANK"] = env_info["local_rank"]
+    os.environ["WORLD_SIZE"] = env_info["world_size"]
+    os.environ["MASTER_PORT"] = env_info["master_port"]
+    os.environ["MASTER_ADDR"] = env_info["master_addr"]
 
 
 def get_model_numel(model: nn.Module) -> int:
@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
     return target_receivers
 
 
-def state_dict_to(state_dict: Dict[str, Any],
-                  dtype: torch.dtype = torch.float16,
-                  device: torch.device = torch.device('cpu')):
-    '''
+def state_dict_to(
+    state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
+):
+    """
     keep state_dict intact
-    '''
+    """
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
         new_state_dict[k] = v.to(dtype=dtype, device=device)
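Note: `state_dict_to` copies every tensor to the requested dtype/device instead of mutating the source, which is what makes it safe to downcast a live model's weights to fp16 on CPU before shipping them through Ray's object store. A small sketch of that use (the linear layer is a stand-in model):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# fp16 CPU copy for cheap transfer; the live fp32 weights are untouched
shipped = state_dict_to(model.state_dict(), dtype=torch.float16, device=torch.device("cpu"))

assert shipped["weight"].dtype == torch.float16
assert model.weight.dtype == torch.float32  # the original stays intact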
@ -3,8 +3,4 @@ from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer
 
-__all__ = [
-    'SLTrainer', 'OnPolicyTrainer',
-    'RewardModelTrainer', 'SFTTrainer',
-    'PPOTrainer'
-]
+__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 data_buffer: NaiveExperienceBuffer,
-                 sample_buffer: bool,
-                 dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        data_buffer: NaiveExperienceBuffer,
+        sample_buffer: bool,
+        dataloader_pin_memory: bool,
+        callbacks: List[Callback] = [],
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.data_buffer = data_buffer
@ -2,4 +2,4 @@ from .base import Callback
 from .performance_evaluator import PerformanceEvaluator
 from .save_checkpoint import SaveCheckpoint
 
-__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
+__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]

@ -5,7 +5,7 @@ from coati.experience_maker import Experience
 
 class Callback(ABC):
     """
     Base callback class. It defines the interface for callbacks.
     """
 
     def on_fit_start(self) -> None:
@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
 
 
 def divide(x: float, y: float) -> float:
     if y == 0:
-        return float('inf')
-    elif y == float('inf'):
-        return float('nan')
+        return float("inf")
+    elif y == float("inf"):
+        return float("nan")
     return x / y
 
 

@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
 
 
 class Timer:
-
     def __init__(self) -> None:
         self.start_time: Optional[float] = None
-        self.duration: float = 0.
+        self.duration: float = 0.0
 
     def start(self) -> None:
         self.start_time = time()

@ -52,7 +51,7 @@ class Timer:
         self.start_time = None
 
     def reset(self) -> None:
-        self.duration = 0.
+        self.duration = 0.0
 
 
 class PerformanceEvaluator(Callback):
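Note: a quick sketch of how such a Timer pairs with `divide` to get a safe per-sample latency. This assumes the Timer also has an end() that accumulates elapsed time into `duration`, as the surrounding hunks suggest; the workload is fabricated:

from time import sleep

t = Timer()
num_samples = 0
for _ in range(3):
    t.start()
    sleep(0.01)      # stand-in for a training step
    t.end()          # assumption: accumulates into t.duration
    num_samples += 8

# divide() guards the degenerate cases: 0 samples -> inf, not ZeroDivisionError
latency_per_sample = divide(t.duration, num_samples)
print(f"{latency_per_sample:.4f} s/sample")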
@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
         ignore_episodes: The number of episodes to ignore when calculating the performance.
     """
 
-    def __init__(self,
-                 actor_num_params: int,
-                 critic_num_params: int,
-                 initial_model_num_params: int,
-                 reward_model_num_params: int,
-                 enable_grad_checkpoint: bool = False,
-                 ignore_episodes: int = 0) -> None:
+    def __init__(
+        self,
+        actor_num_params: int,
+        critic_num_params: int,
+        initial_model_num_params: int,
+        reward_model_num_params: int,
+        enable_grad_checkpoint: bool = False,
+        ignore_episodes: int = 0,
+    ) -> None:
         super().__init__()
         self.world_size = get_world_size()
         self.actor_num_params = actor_num_params

@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
         avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
         avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
 
-        avg_make_experience_throughput = self.make_experience_num_samples * \
-            self.world_size / (avg_make_experience_duration + 1e-12)
+        avg_make_experience_throughput = (
+            self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
+        )
         avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
 
         avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)

@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
         learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
 
         print_rank_0(
-            f'Performance summary:\n'
-            + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
-            + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
-            + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
-            + f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
-            + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
-            + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
+            f"Performance summary:\n"
+            + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+            + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+            + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+            + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+            + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+            + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
         )
@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
 
     """
 
-    def __init__(self,
-                 path: str,
-                 interval: int,
-                 strategy: Strategy,
-                 actor: nn.Module = None,
-                 critic: nn.Module = None,
-                 actor_optim: Optimizer = None,
-                 critic_optim: Optimizer = None) -> None:
+    def __init__(
+        self,
+        path: str,
+        interval: int,
+        strategy: Strategy,
+        actor: nn.Module = None,
+        critic: nn.Module = None,
+        actor_optim: Optimizer = None,
+        critic_optim: Optimizer = None,
+    ) -> None:
         super().__init__()
-        self.path = os.path.join(path, 'checkpoint')
+        self.path = os.path.join(path, "checkpoint")
         self.interval = interval
         self.strategy = strategy
-        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+        self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
 
     def on_episode_end(self, episode: int) -> None:
         if (episode + 1) % self.interval != 0:
             return
-        base_path = os.path.join(self.path, f'episode_{episode}')
+        base_path = os.path.join(self.path, f"episode_{episode}")
         if not os.path.exists(base_path):
             os.makedirs(base_path)
 
         for model in self.model_dict.keys():
 
             # save model
             if self.model_dict[model][0] is None:
                 # saving only optimizer states is meaningless, so it would be skipped
                 continue
-            model_path = os.path.join(base_path, f'{model}.pt')
+            model_path = os.path.join(base_path, f"{model}.pt")
             self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
 
             # save optimizer

@ -71,5 +72,5 @@ class SaveCheckpoint(Callback):
                 continue
             only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
             rank = 0 if is_rank_0() else dist.get_rank()
-            optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+            optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
             self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
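Note: a hedged sketch of wiring this callback into a trainer: checkpoints land under `<path>/checkpoint/episode_<n>` every `interval` episodes. The strategy/model/optimizer objects are placeholders from earlier factories, not a verbatim snippet from this repo:

save_cb = SaveCheckpoint(
    path="./outputs",   # files go to ./outputs/checkpoint/episode_<n>/
    interval=10,        # save every 10 episodes
    strategy=strategy,
    actor=actor,
    critic=critic,
    actor_optim=actor_optim,
    critic_optim=critic_optim,
)
# then passed alongside other callbacks, e.g. callbacks=[save_cb, perf_evaluator]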
@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
 from coati.models.utils import calc_action_log_probs
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DistributedSampler
 from tqdm import tqdm
 
 from colossalai.utils import get_current_device

@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
     hf_model = get_base_model(unwrapper_model)
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
+    if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
+        new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
 
-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
+    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
+        new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
 
     return new_kwargs
@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 actor: Actor,
-                 critic: Critic,
-                 reward_model: nn.Module,
-                 initial_model: Actor,
-                 actor_optim: Optimizer,
-                 critic_optim: Optimizer,
-                 kl_coef: float = 0.1,
-                 ptx_coef: float = 0.9,
-                 train_batch_size: int = 8,
-                 buffer_limit: int = 0,
-                 buffer_cpu_offload: bool = True,
-                 eps_clip: float = 0.2,
-                 vf_coef: float = 1.0,
-                 value_clip: float = 0.4,
-                 sample_buffer: bool = False,
-                 dataloader_pin_memory: bool = True,
-                 offload_inference_models: bool = True,
-                 callbacks: List[Callback] = [],
-                 **generate_kwargs
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        actor: Actor,
+        critic: Critic,
+        reward_model: nn.Module,
+        initial_model: Actor,
+        actor_optim: Optimizer,
+        critic_optim: Optimizer,
+        kl_coef: float = 0.1,
+        ptx_coef: float = 0.9,
+        train_batch_size: int = 8,
+        buffer_limit: int = 0,
+        buffer_cpu_offload: bool = True,
+        eps_clip: float = 0.2,
+        vf_coef: float = 1.0,
+        value_clip: float = 0.4,
+        sample_buffer: bool = False,
+        dataloader_pin_memory: bool = True,
+        offload_inference_models: bool = True,
+        callbacks: List[Callback] = [],
+        **generate_kwargs,
+    ) -> None:
         if isinstance(strategy, GeminiStrategy):
-            assert not offload_inference_models, \
-                "GeminiPlugin is not compatible with manual model.to('cpu')"
+            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
 
         data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
-        super().__init__(
-            strategy, data_buffer,
-            sample_buffer, dataloader_pin_memory,
-            callbacks
-        )
+        super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
 
         self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
         self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer):
         num_actions = experience.action_mask.size(1)
         actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
         action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
-        actor_loss = self.actor_loss_fn(action_log_probs,
-                                        experience.action_log_probs,
-                                        experience.advantages,
-                                        action_mask=experience.action_mask)
+        actor_loss = self.actor_loss_fn(
+            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+        )
 
         # ptx loss
         if self.ptx_coef != 0:
             batch = self.pretrain_dataloader.next()
             batch = to_device(batch, self.device)
-            ptx_log_probs = self.actor(batch['input_ids'],
-                                       attention_mask=batch['attention_mask'])['logits']
-            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
+            ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
+            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
             actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
 
         self.strategy.backward(actor_loss, self.actor, self.actor_optim)
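Note: PolicyLoss here is the standard PPO clipped surrogate: with probability ratio r = exp(logp_new - logp_old), the objective is -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)], averaged over action tokens via the mask. A minimal sketch of that computation (eps and shapes are illustrative; this is the textbook formula, not a copy of coati's PolicyLoss):

import torch

def clipped_policy_loss(logp_new, logp_old, advantages, action_mask, eps_clip=0.2):
    ratio = (logp_new - logp_old).exp()                    # r_t
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    loss = -torch.min(surr1, surr2)                        # pessimistic bound
    return (loss * action_mask).sum() / action_mask.sum()  # mean over real action tokens

logp_new = torch.randn(2, 5)
logp_old = logp_new.detach() + 0.1 * torch.randn(2, 5)
adv = torch.randn(2, 5)
mask = torch.ones(2, 5)
print(clipped_policy_loss(logp_new, logp_old, adv, mask))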
@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer):
         self.actor_optim.zero_grad()
 
         # value loss
-        values = self.critic(experience.sequences,
-                             action_mask=experience.action_mask,
-                             attention_mask=experience.attention_mask)
-        critic_loss = self.critic_loss_fn(values,
-                                          experience.values,
-                                          experience.reward,
-                                          action_mask=experience.action_mask)
+        values = self.critic(
+            experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+        )
+        critic_loss = self.critic_loss_fn(
+            values, experience.values, experience.reward, action_mask=experience.action_mask
+        )
         critic_loss = critic_loss * self.vf_coef
         self.strategy.backward(critic_loss, self.critic, self.critic_optim)
         self.strategy.optimizer_step(self.critic_optim)
         self.critic_optim.zero_grad()
 
-        return {'reward': experience.reward.mean().item()}
+        return {"reward": experience.reward.mean().item()}
 
     def _learn(self, update_step: int):
         if self.offload_inference_models:
-            self.experience_maker.initial_model.to('cpu')
-            self.experience_maker.reward_model.to('cpu')
+            self.experience_maker.initial_model.to("cpu")
+            self.experience_maker.reward_model.to("cpu")
 
         # buffer may be empty at first, we should rebuild at each training
         if self.sample_buffer:
@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer):
|
|||||||
else:
|
else:
|
||||||
if isinstance(self.dataloader.sampler, DistributedSampler):
|
if isinstance(self.dataloader.sampler, DistributedSampler):
|
||||||
self.dataloader.sampler.set_epoch(update_step)
|
self.dataloader.sampler.set_epoch(update_step)
|
||||||
pbar = tqdm(
|
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
|
||||||
self.dataloader,
|
|
||||||
desc=f'Train epoch [{update_step + 1}]',
|
|
||||||
disable=not is_rank_0()
|
|
||||||
)
|
|
||||||
for experience in pbar:
|
for experience in pbar:
|
||||||
self._on_learn_batch_start()
|
self._on_learn_batch_start()
|
||||||
experience.to_device(self.device)
|
experience.to_device(self.device)
|
||||||
|
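The actor update above mixes a pretraining (ptx) loss into the PPO policy loss before backpropagation. A minimal sketch of that interpolation, with the helper name illustrative rather than taken from the trainer:

    import torch

    def mix_losses(actor_loss: torch.Tensor, ptx_loss: torch.Tensor, ptx_coef: float) -> torch.Tensor:
        # ptx_coef interpolates between the pretraining LM loss and the PPO policy loss;
        # with ptx_coef == 0 the pretraining term vanishes, which the `if` guard above exploits.
        return ptx_loss * ptx_coef + actor_loss * (1 - ptx_coef)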
@@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer):

         if is_rank_0():
             log = pd.DataFrame(
-                [[(epoch + 1) * len(self.train_dataloader),
-                  self.loss.item(), self.dist, self.acc]],
-                columns=['step', 'loss', 'dist', 'acc']
+                [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
+                columns=["step", "loss", "dist", "acc"],
             )
-            log.to_csv('log.csv', mode='a', header=False, index=False)
+            log.to_csv("log.csv", mode="a", header=False, index=False)

     def _train(self, epoch):
         self.model.train()
         step_bar = tqdm.trange(
-            len(self.train_dataloader),
-            desc='Train step of epoch %d' % epoch,
-            disable=not is_rank_0()
+            len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
         )
         cnt = 0
         for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer):
             step_bar.update()
         step_bar.close()

-    def _before_fit(self,
-                    train_dataloader: DataLoader,
-                    valid_dataloader: DataLoader,
-                    eval_dataloader: DataLoader):
+    def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
         """
         Args:
             train_dataloader (DataLoader): the dataloader to use for training
@@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer):
             eval_dataloader (DataLoader): the dataloader to use for evaluation
         """
         super()._before_fit()
-        self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+        self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

         self.train_dataloader = train_dataloader
         self.valid_dataloader = valid_dataloader
@@ -39,8 +39,9 @@ class SFTTrainer(SLTrainer):
         accumulation_steps: int = 8,
     ) -> None:
         if accumulation_steps > 1:
-            assert not isinstance(strategy, GeminiStrategy), \
-                "Accumulation steps are not supported in stage 3 of ColossalAI"
+            assert not isinstance(
+                strategy, GeminiStrategy
+            ), "Accumulation steps are not supported in stage 3 of ColossalAI"

         super().__init__(strategy, max_epochs, model, optim)

@@ -50,15 +51,11 @@ class SFTTrainer(SLTrainer):
     def _train(self, epoch: int):
         self.model.train()
         for batch_id, batch in enumerate(self.train_dataloader):
-
             batch = to_device(batch, torch.cuda.current_device())
             if "attention_mask" in batch:
-                outputs = self.model(batch["input_ids"],
-                                     attention_mask=batch["attention_mask"],
-                                     labels=batch["labels"])
+                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
             else:
-                outputs = self.model(batch["input_ids"],
-                                     labels=batch["labels"])
+                outputs = self.model(batch["input_ids"], labels=batch["labels"])

             loss = outputs.loss
             loss = loss / self.accumulation_steps
@@ -73,12 +70,14 @@ class SFTTrainer(SLTrainer):
                 self.optimizer.zero_grad()
                 self.scheduler.step()
                 if is_rank_0() and self.use_wandb:
-                    wandb.log({
-                        "loss": self.total_loss / self.accumulation_steps,
-                        "lr": self.scheduler.get_last_lr()[0],
-                        "epoch": epoch,
-                        "batch_id": batch_id
-                    })
+                    wandb.log(
+                        {
+                            "loss": self.total_loss / self.accumulation_steps,
+                            "lr": self.scheduler.get_last_lr()[0],
+                            "epoch": epoch,
+                            "batch_id": batch_id,
+                        }
+                    )
                     self.total_loss = 0
                 self.step_bar.update()

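The SFT hunks rely on gradient accumulation: each micro-batch loss is scaled by 1/accumulation_steps and the optimizer only steps once per accumulation window. A self-contained toy sketch of that pattern (the model, data, and learning rate here are placeholders, not the trainer's):

    import torch
    from torch import nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accumulation_steps = 8
    batches = [torch.randn(2, 4) for _ in range(16)]  # toy data

    optimizer.zero_grad()
    for batch_id, x in enumerate(batches):
        loss = model(x).pow(2).mean() / accumulation_steps  # scale per micro-batch
        loss.backward()  # gradients accumulate across micro-batches
        if (batch_id + 1) % accumulation_steps == 0:
            optimizer.step()  # one optimizer step per accumulation window
            optimizer.zero_grad()

Scaling the loss keeps the accumulated gradient equal, on average, to that of one large batch.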
@@ -89,9 +88,9 @@ class SFTTrainer(SLTrainer):
                 loss_sum, num_seen = 0, 0
                 for batch in self.eval_dataloader:
                     batch = to_device(batch, torch.cuda.current_device())
-                    outputs = self.model(batch["input_ids"],
-                                         attention_mask=batch["attention_mask"],
-                                         labels=batch["labels"])
+                    outputs = self.model(
+                        batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
+                    )
                     loss = outputs.loss

                     loss_sum += loss.item()
@@ -99,13 +98,15 @@ class SFTTrainer(SLTrainer):

                 loss_mean = loss_sum / num_seen
                 if dist.get_rank() == 0:
-                    self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
+                    self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")

-    def _before_fit(self,
-                    train_dataloader: DataLoader,
-                    eval_dataloader: Optional[DataLoader] = None,
-                    logger: Optional[DistributedLogger] = None,
-                    use_wandb: bool = False):
+    def _before_fit(
+        self,
+        train_dataloader: DataLoader,
+        eval_dataloader: Optional[DataLoader] = None,
+        logger: Optional[DistributedLogger] = None,
+        use_wandb: bool = False,
+    ):
         """
         Args:
             train_dataloader: the dataloader to use for training
@@ -124,6 +125,6 @@ class SFTTrainer(SLTrainer):
         self.no_epoch_bar = True
         self.step_bar = tqdm.trange(
             len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
-            desc=f'steps',
-            disable=not is_rank_0()
+            desc=f"steps",
+            disable=not is_rank_0(),
         )
@@ -2,7 +2,4 @@ from .base import Strategy
 from .colossalai import GeminiStrategy, LowLevelZeroStrategy
 from .ddp import DDPStrategy

-__all__ = [
-    'Strategy', 'DDPStrategy',
-    'LowLevelZeroStrategy', 'GeminiStrategy'
-]
+__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
@@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]

 class Strategy(ABC):
     """
     Base class for training strategies.
     """

     def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
@@ -83,16 +83,18 @@ class Strategy(ABC):
                 rets.append((model, optimizer))
             elif isinstance(arg, Dict):
                 model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
-                boost_result = dict(model=model,
-                                    optimizer=optimizer,
-                                    criterion=criterion,
-                                    dataloader=dataloader,
-                                    lr_scheduler=lr_scheduler)
+                boost_result = dict(
+                    model=model,
+                    optimizer=optimizer,
+                    criterion=criterion,
+                    dataloader=dataloader,
+                    lr_scheduler=lr_scheduler,
+                )
                 # remove None values
                 boost_result = {key: value for key, value in boost_result.items() if value is not None}
                 rets.append(boost_result)
             else:
-                raise RuntimeError(f'Type {type(arg)} is not supported')
+                raise RuntimeError(f"Type {type(arg)} is not supported")

         return rets[0] if len(rets) == 1 else rets

@@ -125,11 +127,9 @@ class Strategy(ABC):
         return DistributedSampler(dataset, 1, 0)

     @abstractmethod
-    def save_pretrained(self,
-                        model: nn.Module,
-                        path: str,
-                        only_rank0: bool = True,
-                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_pretrained(
+        self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+    ) -> None:
         pass

     @abstractmethod
@@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy):

     """

-    def __init__(self,
-                 stage: int = 2,
-                 precision: str = 'fp16',
-                 seed: int = 42,
-                 placement_policy: str = 'cuda',
-                 reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
-                 overlap_communication: bool = True,    # only for stage 1&2
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
+    def __init__(
+        self,
+        stage: int = 2,
+        precision: str = "fp16",
+        seed: int = 42,
+        placement_policy: str = "cuda",
+        reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
+        overlap_communication: bool = True,  # only for stage 1&2
+        initial_scale: float = 2**16,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        min_scale: float = 1,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+        norm_type: float = 2.0,
+    ) -> None:
         assert stage in (1, 2), f'Unsupported stage "{stage}"'
-        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
-        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
+        assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
+        assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'

         plugin_initializer = lambda: LowLevelZeroPlugin(
             # zero_config
@@ -71,7 +71,7 @@ class LowLevelZeroStrategy(DDPStrategy):
             # zero_optim_config
             reduce_bucket_size_in_m=reduce_bucket_size,
             overlap_communication=overlap_communication,
-            cpu_offload=(placement_policy == 'cpu'),
+            cpu_offload=(placement_policy == "cpu"),
             # optim_config
             initial_scale=initial_scale,
             growth_factor=growth_factor,
@@ -81,14 +81,15 @@ class LowLevelZeroStrategy(DDPStrategy):
             min_scale=min_scale,
             max_scale=max_scale,
             max_norm=max_norm,
-            norm_type=norm_type
+            norm_type=norm_type,
         )

         super().__init__(seed, plugin_initializer)

     def _post_init(self) -> None:
-        assert isinstance(self.plugin, LowLevelZeroPlugin), \
-            f'{type(self).__name__}\'s plugin is not initialized properly.'
+        assert isinstance(
+            self.plugin, LowLevelZeroPlugin
+        ), f"{type(self).__name__}'s plugin is not initialized properly."

     def setup_distributed(self) -> None:
         colossalai.launch_from_torch({}, seed=self.seed)
@@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy):

     """

-    def __init__(self,
-                 seed: int = 42,
-                 shard_init: bool = False,    # only for stage 3
-                 placement_policy: str = 'cuda',
-                 pin_memory: bool = True,    # only for stage 3
-                 force_outputs_fp32: bool = False,    # only for stage 3
-                 search_range_m: int = 32,    # only for stage 3
-                 hidden_dim: Optional[int] = None,    # only for stage 3
-                 min_chunk_size_m: float = 32,    # only for stage 3
-                 gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
+    def __init__(
+        self,
+        seed: int = 42,
+        shard_init: bool = False,  # only for stage 3
+        placement_policy: str = "cuda",
+        pin_memory: bool = True,  # only for stage 3
+        force_outputs_fp32: bool = False,  # only for stage 3
+        search_range_m: int = 32,  # only for stage 3
+        hidden_dim: Optional[int] = None,  # only for stage 3
+        min_chunk_size_m: float = 32,  # only for stage 3
+        gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
+        initial_scale: float = 2**16,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        min_scale: float = 1,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+        norm_type: float = 2.0,
+    ) -> None:
-        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+        assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'

         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
             warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. '
-                'Please load weights after strategy.prepare()'
+                f"Shard init is not supported model.from_pretrained() yet. "
+                "Please load weights after strategy.prepare()"
             )
         self.shard_init = shard_init

-        warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
+        warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

         # NOTE: dist should be initialized before calling get_current_device()
         plugin_initializer = lambda: GeminiPlugin(
             # gemini_config
             device=get_current_device(),
             placement_policy=placement_policy,
-            precision='fp16',
+            precision="fp16",
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,
             strict_ddp_mode=shard_init,
@@ -187,14 +188,13 @@ class GeminiStrategy(DDPStrategy):
             min_scale=min_scale,
             max_scale=max_scale,
             max_norm=max_norm,
-            norm_type=norm_type
+            norm_type=norm_type,
         )

         super().__init__(seed, plugin_initializer)

     def _post_init(self) -> None:
-        assert isinstance(self.plugin, GeminiPlugin), \
-            f'{type(self).__name__}\'s plugin is not initialized properly.'
+        assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

     def setup_distributed(self) -> None:
         colossalai.launch_from_torch({}, seed=self.seed)
@@ -203,10 +203,9 @@ class GeminiStrategy(DDPStrategy):
         world_size = dist.get_world_size()
         shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
         default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
-        return ColoInitContext(device=get_current_device(),
-                               dtype=torch.half,
-                               default_pg=shard_pg,
-                               default_dist_spec=default_dist_spec)
+        return ColoInitContext(
+            device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
+        )

     def unwrap_model(self, model: nn.Module) -> nn.Module:
         assert isinstance(model, GeminiModel)
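Both strategies above assign a zero-argument lambda to plugin_initializer instead of constructing the plugin eagerly, so the plugin can be built after the distributed process group is up. A stripped-down sketch of that deferred-initializer pattern (the classes here are stand-ins, not ColossalAI's):

    from typing import Callable, Optional

    class Plugin:  # stand-in for the real plugin base class
        pass

    class Strategy:
        def __init__(self, plugin_initializer: Callable[[], Optional[Plugin]] = lambda: None) -> None:
            self._initializer = plugin_initializer  # store the factory; do not call it yet
            self.plugin: Optional[Plugin] = None

        def setup(self) -> None:
            # distributed init would happen first; only then is the plugin materialized
            self.plugin = self._initializer()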
@@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):

 class DDPStrategy(Strategy):
     """
     Strategy for distributed training using torch.distributed.
     """

-    def __init__(self,
-                 seed: int = 42,
-                 plugin_initializer: Callable = TorchDDPPlugin
-                 ) -> None:
+    def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
         self.seed = seed
         super().__init__(plugin_initializer)

     def _try_init_dist(self, force: bool = False) -> None:
         try:
-            rank = int(os.environ['RANK'])
-            local_rank = int(os.environ['LOCAL_RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
-            host = os.environ['MASTER_ADDR']
-            port = int(os.environ['MASTER_PORT'])
-            dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+            rank = int(os.environ["RANK"])
+            local_rank = int(os.environ["LOCAL_RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+            host = os.environ["MASTER_ADDR"]
+            port = int(os.environ["MASTER_PORT"])
+            dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
             torch.cuda.set_device(local_rank)
         except KeyError as e:
             if force:
@@ -60,8 +57,7 @@ class DDPStrategy(Strategy):
                 raise e

     def _post_init(self) -> None:
-        assert isinstance(self.plugin, TorchDDPPlugin), \
-            f'{type(self).__name__}\'s plugin is not initialized properly.'
+        assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

     def setup_distributed(self) -> None:
         self._try_init_dist(force=True)
@@ -73,12 +69,14 @@ class DDPStrategy(Strategy):
         torch.manual_seed(seed)

     def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
-        return self.plugin.prepare_dataloader(data_buffer,
-                                              batch_size=data_buffer.sample_batch_size,
-                                              shuffle=True,
-                                              drop_last=True,
-                                              pin_memory=pin_memory,
-                                              collate_fn=data_buffer.collate_fn)
+        return self.plugin.prepare_dataloader(
+            data_buffer,
+            batch_size=data_buffer.sample_batch_size,
+            shuffle=True,
+            drop_last=True,
+            pin_memory=pin_memory,
+            collate_fn=data_buffer.collate_fn,
+        )

     def setup_sampler(self, dataset) -> DistributedSampler:
         # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
@@ -88,11 +86,9 @@ class DDPStrategy(Strategy):
         assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
         return model.unwrap()

-    def save_pretrained(self,
-                        model: nn.Module,
-                        path: str,
-                        only_rank0: bool = True,
-                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_pretrained(
+        self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+    ) -> None:
         if not only_rank0 or dist.get_rank() == 0:
             unwrapped_model = self.unwrap_model(model)
             assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
@@ -103,17 +99,11 @@ class DDPStrategy(Strategy):
             if tokenizer is not None:
                 tokenizer.save_pretrained(path)
             model_path = os.path.join(path, "pytorch_model.bin")
-            self.save_model(model,
-                            model_path,
-                            only_rank0=only_rank0)
+            self.save_model(model, model_path, only_rank0=only_rank0)

-        def _replace_keys(model_path: str,
-                          replace_fn: Callable):
+        def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
-            state_dict = {
-                replace_fn(k): v
-                for k, v in state_dict.items()
-            }
+            state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)

         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
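The nested _replace_keys helper rewrites every key of a saved state dict through replace_fn, which is how the "model." prefix noted in the FIXME can be stripped. A toy check of the dict-comprehension trick, with stand-in values instead of tensors:

    from collections import OrderedDict

    state_dict = OrderedDict([("model.weight", 1), ("model.bias", 2)])  # stand-ins for tensors
    replace_fn = lambda k: k.replace("model.", "", 1)  # drop the first "model." prefix
    state_dict = {replace_fn(k): v for k, v in state_dict.items()}
    assert list(state_dict) == ["weight", "bias"]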
@@ -124,13 +114,13 @@ class DDPStrategy(Strategy):
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)
-        if 'requires_grad_only' in config and config['requires_grad_only'] == True:
+        if "requires_grad_only" in config and config["requires_grad_only"] == True:
             state_dict = get_grad_required_state_dict(model)
         else:
             state_dict = model.state_dict()

-        if 'shard_size' in config:
-            shard_size = config['shard_size']
+        if "shard_size" in config:
+            shard_size = config["shard_size"]
             accumulate_size = 0
             state_dict_shard = OrderedDict()
             for name, param in state_dict.items():
@@ -4,7 +4,6 @@ import numpy as np


 class DistributedSampler:
-
     def __init__(self, dataset, num_replicas: int, rank: int) -> None:
         self.dataset = dataset
         self.num_replicas = num_replicas
@@ -12,7 +11,7 @@ class DistributedSampler:

         if len(self.dataset) % self.num_replicas != 0:
             self.num_samples = math.ceil(
                 (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
             )
         else:
             self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@@ -20,10 +19,10 @@ class DistributedSampler:
         self.total_size = self.num_samples * self.num_replicas

         indices = list(range(len(self.dataset)))
-        indices = indices[:self.total_size]
+        indices = indices[: self.total_size]
         assert len(indices) == self.total_size
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank : self.total_size : self.num_replicas]
         assert len(indices) == self.num_samples
         self.indices = indices

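The sampler truncates to total_size and then takes a rank-strided slice, so each replica sees a disjoint subset of the same length. A quick standalone check of that slicing:

    num_replicas, total_size = 4, 12
    indices = list(range(total_size))
    shards = [indices[rank : total_size : num_replicas] for rank in range(num_replicas)]
    assert shards[0] == [0, 4, 8] and shards[1] == [1, 5, 9]  # disjoint, equal-sized shards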
@@ -42,7 +42,6 @@ def is_rank_0() -> bool:


 def to_device(x: Any, device: torch.device) -> Any:
-
     def _to(t: Any):
         if isinstance(t, torch.Tensor):
             return t.to(device)
@@ -1,5 +1,4 @@
 import argparse
-import json
 import os

 import openai
@@ -9,7 +8,8 @@ from utils import jload

 def main(args):
     assert len(args.answer_file_list) == len(
-        args.model_name_list), "The number of answer files and model names should be equal!"
+        args.model_name_list
+    ), "The number of answer files and model names should be equal!"

     # load config
     config = jload(args.config_file)
@@ -36,7 +36,8 @@ def main(args):

     if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
         raise Exception(
-            "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")
+            "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
+        )

     if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
         raise Exception(
@@ -44,8 +45,15 @@ def main(args):
         )

     # initialize evaluator
-    evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
-                          config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference)
+    evaluator = Evaluator(
+        metrics_per_category,
+        battle_prompt,
+        gpt_evaluation_prompt,
+        args.gpt_model,
+        config["language"],
+        config.get("path_for_UniEval", None),
+        args.gpt_with_reference,
+    )
     if len(args.model_name_list) == 2:
         answers1 = jload(args.answer_file_list[0])
         answers2 = jload(args.answer_file_list[1])
@@ -68,41 +76,41 @@ def main(args):
         raise ValueError(f'Unsupported language {config["language"]}!')


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
-    parser.add_argument('--config_file',
-                        type=str,
-                        default=None,
-                        required=True,
-                        help='path to the file of target results')
-    parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
-    parser.add_argument('--gpt_evaluation_prompt_file',
-                        type=str,
-                        default=None,
-                        help='path to the prompt file for gpt evaluation')
-    parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
-    parser.add_argument('--answer_file_list',
-                        type=str,
-                        nargs='+',
-                        default=[],
-                        required=True,
-                        help='path to the answer files of at most 2 models')
-    parser.add_argument('--model_name_list',
-                        type=str,
-                        nargs='+',
-                        default=[],
-                        required=True,
-                        help='the names of at most 2 models')
-    parser.add_argument('--gpt_model',
-                        default="gpt-3.5-turbo",
-                        choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
-                        help='which GPT model to use for evaluation')
-    parser.add_argument('--gpt_with_reference',
-                        default=False,
-                        action="store_true",
-                        help='whether to include reference answer in gpt evaluation')
-    parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
-    parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
+    parser.add_argument(
+        "--config_file", type=str, default=None, required=True, help="path to the file of target results"
+    )
+    parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
+    parser.add_argument(
+        "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
+    )
+    parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
+    parser.add_argument(
+        "--answer_file_list",
+        type=str,
+        nargs="+",
+        default=[],
+        required=True,
+        help="path to the answer files of at most 2 models",
+    )
+    parser.add_argument(
+        "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
+    )
+    parser.add_argument(
+        "--gpt_model",
+        default="gpt-3.5-turbo",
+        choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
+        help="which GPT model to use for evaluation",
+    )
+    parser.add_argument(
+        "--gpt_with_reference",
+        default=False,
+        action="store_true",
+        help="whether to include reference answer in gpt evaluation",
+    )
+    parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
+    parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
     args = parser.parse_args()

     if args.openai_key is not None:
@@ -3,20 +3,27 @@ from typing import Any, Dict, List

 import gpt_evaluate
 import metrics
-import pandas as pd
 import unieval
 from utils import analyze_automatic_results, get_data_per_category, save_automatic_results


 class Evaluator(object):
     """
     A class named Evaluator includes GPT-3.5/GPT-4 evaluation
     and automatic evaluation

     """

-    def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
-                 gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None:
+    def __init__(
+        self,
+        params: Dict[str, Any],
+        battle_prompt: Dict[str, Any],
+        gpt_evaluation_prompt: Dict[str, Any],
+        gpt_model: str,
+        language: str,
+        path_for_UniEval: Dict[str, str],
+        gpt_with_reference: bool,
+    ) -> None:
         self.params = params
         self.battle_prompt = battle_prompt
         self.gpt_evaluation_prompt = gpt_evaluation_prompt
@@ -103,7 +110,8 @@ class Evaluator(object):

         if self.params[category]["UniEval"] and self.language == "cn":
             raise Exception(
-                "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.")
+                "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file."
+            )

         category_metrics = self.params[category]["UniEval"]

@@ -134,10 +142,9 @@ class Evaluator(object):
             sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]

             data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
-            scores = uni_evaluator.evaluate(data,
-                                            category,
-                                            dims=list(self.unieval_metric_stats[task][category].keys()),
-                                            overall=False)
+            scores = uni_evaluator.evaluate(
+                data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False
+            )
             avg_scores = unieval.calculate_average_score(scores)

             self.unieval_metric_stats[task][category].update(avg_scores)
@@ -165,7 +172,8 @@ class Evaluator(object):
                 category,
                 self.gpt_model,
                 self.language,
-                references=targets_per_category[category] if self.gpt_with_reference else None)
+                references=targets_per_category[category] if self.gpt_with_reference else None,
+            )

     def save(self, path: str, model_name_list: List[str]) -> None:
         """
@@ -204,16 +212,18 @@ class Evaluator(object):
             gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
             gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")

-            all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
-                                                                       self.gpt_evaluation_results,
-                                                                       gpt_evaluation_results_save_path)
+            all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
+                model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
+            )

             # Start to calculate scores and save statistics.
             gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
-            gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
-                                                        gpt_evaluation_statistics_save_path)
+            gpt_evaluate.save_gpt_evaluation_statistics(
+                model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
+            )

             # Save charts and csv.
             gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
-            gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
-                                                           gpt_evaluation_analyses_save_path)
+            gpt_evaluate.analyze_gpt_evaluation_statistics(
+                gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
+            )
@@ -14,20 +14,18 @@ import tqdm
 from utils import jdump, jload

 ref_step_template = {
-    "en":
-        "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
-    "cn":
-        "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n"
+    "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
+    "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
 }

 ref_answer_template_general = {
     "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
-    "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n"
+    "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
 }

 ref_answer_template_correctness = {
     "en": "\nA correct answer is as follows:\n\n{answer}\n\n",
-    "cn": "\n标准答案如下:\n\n{answer}\n\n"
+    "cn": "\n标准答案如下:\n\n{answer}\n\n",
 }


@@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
     response = openai.ChatCompletion.create(
         model="gpt-4",
         messages=[
-            {
-                "role": "system",
-                "content": sys_prompt
-            },
+            {"role": "system", "content": sys_prompt},
             {
                 "role": "user",
                 "content": user_prompt,
@@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
             return [float(sp[0]), float(sp[1])]
         else:
             raise Exception(f"Invalid score pair. Got {evaluation}.")
-    except Exception as e:
+    except Exception:
         return [-1, -1]


@@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]

     assert len(answer1) == len(answer2)

-    handles = []
-    evaluation_file = []
-
     total_len = len(answer1)
     question_idx_list = list(range(total_len))

@@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
         assert answer1[i]["id"] == answer2[i]["id"]
         answer_id = answer1[i]["id"]

-        ques = answer1[i]["instruction"] if answer1[i][
-            "input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"]
-        cat = answer1[i]["category"]
+        ques = (
+            answer1[i]["instruction"]
+            if answer1[i]["input"] == ""
+            else answer1[i]["instruction"] + " " + answer1[i]["input"]
+        )
+        answer1[i]["category"]
         ans1 = answer1[i]["output"]
         ans2 = answer2[i]["output"]

@@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->

     step_to_add = ref_step_template[language]

-    for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)"
+    for_the_given_answer = (
+        "{metric} (1-5) (directly give the score for the given answer):"
+        if language == "en"
+        else "{metric} (1-5) (直接对给定答案打分)"
+    )

     # adjective is used to describe the word "answer" in the prompt.
     adjective = "example" if language == "en" else "示例"
@@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
         answer_to_add = ref_answer_template_correctness[language]

     answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
-    step_to_add = step_to_add.format(metric=metric.lower(),
-                                     adjective=adjective) + for_the_given_answer.format(metric=metric)
+    step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
+        metric=metric
+    )

     return answer_to_add + step_to_add

@@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
             for j in range(i):
                 messages_to_send.append(fill_in_message("user", user_messages[j]))
                 messages_to_send.append(
-                    fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]))
+                    fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
+                )

             # Length of user messages == Length of assistant messages + 1
             # Because we always expect the api to response
@@ -351,13 +352,15 @@
     return assistant_responses[-1]


-def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
-                                        inst: Dict[str, Any],
-                                        metrics: List[str],
-                                        language: str,
-                                        reference: Dict[str, Any] = None,
-                                        model: str = "gpt-3.5-turbo",
-                                        max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_without_logprobs(
+    prompt: Dict[str, Any],
+    inst: Dict[str, Any],
+    metrics: List[str],
+    language: str,
+    reference: Dict[str, Any] = None,
+    model: str = "gpt-3.5-turbo",
+    max_tokens: int = 2048,
+) -> Dict[str, Any]:
     """
     Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer.

@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],

     MAX_API_RETRY = 3

-    question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+    question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
     answer = inst["output"]
     inst["evaluation"] = {}

@@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],

     if prompt_reference:
         # Do a 2-round conversation
-        response = multiturn_chat_completion([prompt_1st_round, prompt_reference],
-                                             model,
-                                             max_tokens=max_tokens,
-                                             turns=2)
+        response = multiturn_chat_completion(
+            [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
+        )
     else:
         response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)

@@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
     return inst


-def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
-                                     inst: Dict[str, Any],
-                                     metrics: List[str],
-                                     max_tokens: int = 2048) -> Dict[str, Any]:
+def get_gpt_evaluation_with_logprobs(
+    prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
+) -> Dict[str, Any]:
     """
     Use completion model(text-davinci-003) to evaluate one model answer.
     Only completion models can return log probabilities.
@@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],

     MAX_API_RETRY = 3

-    question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
+    question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
     answer = inst["output"]
     inst["evaluation"] = {}

@@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
     return inst


-def evaluate(answers: List[Dict],
-             prompt: Dict[str, Any],
-             metrics: List[str],
-             category: str,
-             model: str,
-             language: str,
-             references: List[Dict] = None) -> List[Dict]:
+def evaluate(
+    answers: List[Dict],
+    prompt: Dict[str, Any],
+    metrics: List[str],
+    category: str,
+    model: str,
+    language: str,
+    references: List[Dict] = None,
+) -> List[Dict]:
     """
     Use GPT models to evaluate model answers and save evaluation results.

@@ -529,21 +532,23 @@ def evaluate(answers: List[Dict],
             if model == "text-davinci-003":
                 future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
             else:
-                future = executor.submit(get_gpt_evaluation_without_logprobs,
-                                         prompt,
-                                         inst,
-                                         metrics,
-                                         language,
-                                         reference=None if references is None else references[idx],
-                                         model=model,
-                                         max_tokens=1)
+                future = executor.submit(
+                    get_gpt_evaluation_without_logprobs,
+                    prompt,
+                    inst,
+                    metrics,
+                    language,
+                    reference=None if references is None else references[idx],
+                    model=model,
+                    max_tokens=1,
+                )

             futures.append(future)

         for future in tqdm.tqdm(
             concurrent.futures.as_completed(futures),
             desc=f"{category}: ",
             total=len(futures),
         ):
             evaluations.append(future.result())

|
|||||||
return int(results[0])
|
return int(results[0])
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any],
|
def save_gpt_evaluation_results(
|
||||||
save_path: str) -> Dict[str, Any]:
|
model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Save evaluation results for different categories for one model.
|
Save evaluation results for different categories for one model.
|
||||||
|
|
||||||
@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
|
|||||||
scores[metric].append(0)
|
scores[metric].append(0)
|
||||||
elif evaluation["evaluation"][metric]["logprobs"] is not None:
|
elif evaluation["evaluation"][metric]["logprobs"] is not None:
|
||||||
scores[metric].append(
|
scores[metric].append(
|
||||||
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]))
|
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
scores[metric].append(
|
scores[metric].append(
|
||||||
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation))
|
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
|
||||||
|
)
|
||||||
|
|
||||||
statistics = {}
|
statistics = {}
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
|
|||||||
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
|
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
|
||||||
|
|
||||||
for category in tqdm.tqdm(
|
for category in tqdm.tqdm(
|
||||||
frame_per_category.keys(),
|
frame_per_category.keys(),
|
||||||
desc=f"GPT evaluation: ",
|
desc=f"GPT evaluation: ",
|
||||||
total=len(frame_per_category.keys()),
|
total=len(frame_per_category.keys()),
|
||||||
):
|
):
|
||||||
data = pd.DataFrame(frame_per_category[category])
|
data = pd.DataFrame(frame_per_category[category])
|
||||||
|
|
||||||
|
@@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
     """
     bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
    cumulative_bleu = [0] * 4
-    weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
-               (1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
+    weights = [
+        (1.0 / 1.0, 0.0, 0.0, 0.0),
+        (1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0),
+        (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0),
+        (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
+    ]

     for pred, target in zip(preds, targets):
         if language == "cn":
-            pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
-            target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()]
+            pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
+            target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()]
         elif language == "en":
             pred_list = preprocessing_text(pred).split()
             target_list = [preprocessing_text(target).split()]
|
|||||||
|
|
||||||
|
|
||||||
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
|
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
|
||||||
"""Calculate CHRF Score Metric in sentence level.
|
"""Calculate CHRF Score Metric in sentence level."""
|
||||||
"""
|
|
||||||
chrf_score = {"chrf": 0}
|
chrf_score = {"chrf": 0}
|
||||||
cumulative_chrf = []
|
cumulative_chrf = []
|
||||||
|
|
||||||
for pred, target in zip(preds, targets):
|
for pred, target in zip(preds, targets):
|
||||||
if language == "cn":
|
if language == "cn":
|
||||||
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
|
pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
|
||||||
target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
|
target_list = " ".join(jieba.cut(preprocessing_text(target))).split()
|
||||||
elif language == "en":
|
elif language == "en":
|
||||||
pred_list = preprocessing_text(pred).split()
|
pred_list = preprocessing_text(pred).split()
|
||||||
target_list = preprocessing_text(target).split()
|
target_list = preprocessing_text(target).split()
|
||||||
@@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
     all_targets = []

     for pred, target in zip(preds, targets):
-        pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred))))
-        target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target))))
+        pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred))))
+        target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target))))
         all_preds.append(pred_list)
         all_targets.append(target_list)
@@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
     longest common subsequence (LCS) between preds and targets.
     """
     rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
-    all_preds = []
-    all_targets = []

     rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)

     for pred, target in zip(preds, targets):
         score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
-        rouge_scores["rouge1"] += score['rouge1'].fmeasure
-        rouge_scores["rouge2"] += score['rouge2'].fmeasure
-        rouge_scores["rougeL"] += score['rougeL'].fmeasure
+        rouge_scores["rouge1"] += score["rouge1"].fmeasure
+        rouge_scores["rouge2"] += score["rouge2"].fmeasure
+        rouge_scores["rougeL"] += score["rougeL"].fmeasure

     rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
     rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
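
The scorer construction above mirrors Google's rouge_score package; a hedged usage sketch, assuming the vendored Rouge_en module behaves like rouge_score.rouge_scorer:

# Assumes the vendored Rouge_en matches Google's rouge_score package API.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
score = scorer.score("the cat sat on the mat", "a cat was on the mat")  # (target, prediction)
print(score["rouge1"].fmeasure, score["rougeL"].fmeasure)
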
@@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:

     for pred in preds:
         if language == "cn":
-            pred_seg_list = ' '.join(jieba.cut(pred)).split()
+            pred_seg_list = " ".join(jieba.cut(pred)).split()
             count_segs = len(pred_seg_list)
             unique_segs = set(pred_seg_list)
             count_unique_chars = len(unique_segs)
@@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
             split_pred = preprocessing_text(pred).split()
             for n in range(0, 3):
                 for i in range(0, len(split_pred) - n):
-                    ngram = ' '.join(split_pred[i:i + n + 1])
+                    ngram = " ".join(split_pred[i : i + n + 1])
                     unique_ngram[n].add(ngram)
                     all_ngram_count[n] += 1
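
The loop above counts unique n-grams against total n-grams for n = 1..3. A self-contained sketch of the same distinct-n idea (hypothetical helper, not part of this commit):

# Standalone sketch of distinct-n as computed above: unique n-grams / total n-grams.
def distinct_n(tokens, max_n=3):
    scores = {}
    for n in range(max_n):
        ngrams = {" ".join(tokens[i : i + n + 1]) for i in range(len(tokens) - n)}
        total = max(len(tokens) - n, 1)  # guard against empty input
        scores[f"distinct{n + 1}"] = len(ngrams) / total
    return scores

print(distinct_n("the quick brown fox jumps over the lazy dog".split()))
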
@@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language

     for pred, target in zip(preds, targets):
         if language == "cn":
-            pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()]
-            target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()]
+            pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()]
+            target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()]
         elif language == "en":
             pred_list = [char for char in preprocessing_text(pred).split()]
             target_list = [char for char in preprocessing_text(target).split()]
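
For reference, token lists like those built above typically feed the standard overlap formulas. A sketch under the assumption that overlap is counted as a multiset intersection (the repository's exact counting may differ):

from collections import Counter

# Token-level precision/recall/F1 via multiset overlap (an assumption; the
# function continuing below may count overlap differently).
def precision_recall_f1(pred_tokens, target_tokens):
    overlap = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    precision = overlap / len(pred_tokens) if pred_tokens else 0.0
    recall = overlap / len(target_tokens) if target_tokens else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

print(precision_recall_f1("a b c".split(), "a b d".split()))
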
@@ -7,6 +7,9 @@ from .utils import (
 )

 __all__ = [
-    'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
-    'analyze_unieval_results'
+    "get_evaluator",
+    "convert_data_to_unieval_format",
+    "calculate_average_score",
+    "save_unieval_results",
+    "analyze_unieval_results",
 ]
@@ -28,29 +28,29 @@ from .utils import add_question


 class SumEvaluator:
-
-    def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
-        """ Set up evaluator for text summarization """
+    def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+        """Set up evaluator for text summarization"""
         self.scorer = UniEvaluator(
-            model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
+            model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
             max_length=max_length,
             device=device,
-            cache_dir=cache_dir)
-        self.task = 'summarization'
-        self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
+            cache_dir=cache_dir,
+        )
+        self.task = "summarization"
+        self.dimensions = ["coherence", "consistency", "fluency", "relevance"]

     def evaluate(self, data, category, dims=None, overall=True):
         """
         Get the scores of all the given dimensions

         category: The category to be evaluated.

         dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
               four dimensions: coherence, consistency, fluency, relevance.

         overall: indicates whether the overall score is to be calculated.
                  Overall score can be customized to a combination of scores based on different
                  dimensions. The default here is the average score of all the given dimensions.
         """
         n_data = len(data)
         eval_scores = [{} for _ in range(n_data)]
|
|||||||
|
|
||||||
for dim in eval_dims:
|
for dim in eval_dims:
|
||||||
# Calculate average sentence-level scores for 'consistency' and 'fluency'
|
# Calculate average sentence-level scores for 'consistency' and 'fluency'
|
||||||
if dim == 'consistency' or dim == 'fluency':
|
if dim == "consistency" or dim == "fluency":
|
||||||
src_list, output_list = [], []
|
src_list, output_list = [], []
|
||||||
n_sents = [] # the number of sentences in each generated summary
|
n_sents = [] # the number of sentences in each generated summary
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
source = data[i]['source']
|
source = data[i]["source"]
|
||||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||||
n_sents.append(len(system_outputs))
|
n_sents.append(len(system_outputs))
|
||||||
for j in range(len(system_outputs)):
|
for j in range(len(system_outputs)):
|
||||||
src_list.append(source)
|
src_list.append(source)
|
||||||
@ -81,24 +81,26 @@ class SumEvaluator:
|
|||||||
score = []
|
score = []
|
||||||
for cur_n_sent in n_sents:
|
for cur_n_sent in n_sents:
|
||||||
# prevent denominator from being 0
|
# prevent denominator from being 0
|
||||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
|
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
|
||||||
start_idx += cur_n_sent
|
start_idx += cur_n_sent
|
||||||
|
|
||||||
# Calculate summary-level score for 'coherence' and 'relevance'
|
# Calculate summary-level score for 'coherence' and 'relevance'
|
||||||
elif dim == 'coherence' or dim == 'relevance':
|
elif dim == "coherence" or dim == "relevance":
|
||||||
src_list, output_list, ref_list = [], [], []
|
src_list, output_list, ref_list = [], [], []
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
src_list.append(data[i]['source'])
|
src_list.append(data[i]["source"])
|
||||||
output_list.append(data[i]['system_output'])
|
output_list.append(data[i]["system_output"])
|
||||||
if dim == 'relevance':
|
if dim == "relevance":
|
||||||
ref_list.append(data[i]['reference'])
|
ref_list.append(data[i]["reference"])
|
||||||
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
|
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
|
||||||
score = self.scorer.score(input_list, self.task, category, dim)
|
score = self.scorer.score(input_list, self.task, category, dim)
|
||||||
|
|
||||||
# Please customize other dimensions here for summarization
|
# Please customize other dimensions here for summarization
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('The input format for this dimension is still undefined. \
|
raise NotImplementedError(
|
||||||
Please customize it first.')
|
"The input format for this dimension is still undefined. \
|
||||||
|
Please customize it first."
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
eval_scores[i][dim] = score[i]
|
eval_scores[i][dim] = score[i]
|
||||||
@ -106,35 +108,35 @@ class SumEvaluator:
|
|||||||
# Customize your overall score here.
|
# Customize your overall score here.
|
||||||
if overall == True:
|
if overall == True:
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||||
|
|
||||||
return eval_scores
|
return eval_scores
|
||||||
|
|
||||||
|
|
||||||
class DialogEvaluator:
|
class DialogEvaluator:
|
||||||
|
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
"""Set up evaluator for dialogues"""
|
||||||
""" Set up evaluator for dialogues """
|
|
||||||
self.scorer = UniEvaluator(
|
self.scorer = UniEvaluator(
|
||||||
model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
|
model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
device=device,
|
device=device,
|
||||||
cache_dir=cache_dir)
|
cache_dir=cache_dir,
|
||||||
self.task = 'dialogue'
|
)
|
||||||
self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
|
self.task = "dialogue"
|
||||||
|
self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"]
|
||||||
|
|
||||||
def evaluate(self, data, category, dims=None, overall=True):
|
def evaluate(self, data, category, dims=None, overall=True):
|
||||||
"""
|
"""
|
||||||
Get the scores of all the given dimensions
|
Get the scores of all the given dimensions
|
||||||
|
|
||||||
category: The category to be evaluated.
|
category: The category to be evaluated.
|
||||||
|
|
||||||
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
|
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
|
||||||
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
|
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
|
||||||
|
|
||||||
overall: indicates whether the overall score is to be calculated.
|
overall: indicates whether the overall score is to be calculated.
|
||||||
Overall score can be customized to a combination of scores based on different
|
Overall score can be customized to a combination of scores based on different
|
||||||
dimensions. The default here is the average score of all the given dimensions.
|
dimensions. The default here is the average score of all the given dimensions.
|
||||||
"""
|
"""
|
||||||
n_data = len(data)
|
n_data = len(data)
|
||||||
eval_scores = [{} for _ in range(n_data)]
|
eval_scores = [{} for _ in range(n_data)]
|
||||||
@ -147,50 +149,48 @@ class DialogEvaluator:
|
|||||||
|
|
||||||
for dim in eval_dims:
|
for dim in eval_dims:
|
||||||
# Calculate summation score for 'engagingness'
|
# Calculate summation score for 'engagingness'
|
||||||
if dim == 'engagingness':
|
if dim == "engagingness":
|
||||||
src_list, output_list, context_list = [], [], []
|
src_list, output_list, context_list = [], [], []
|
||||||
n_sents = [] # the number of sentences in each generated response
|
n_sents = [] # the number of sentences in each generated response
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
source = data[i]['source']
|
source = data[i]["source"]
|
||||||
context = data[i]['context']
|
context = data[i]["context"]
|
||||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||||
n_sents.append(len(system_outputs))
|
n_sents.append(len(system_outputs))
|
||||||
for j in range(len(system_outputs)):
|
for j in range(len(system_outputs)):
|
||||||
src_list.append(source)
|
src_list.append(source)
|
||||||
context_list.append(context)
|
context_list.append(context)
|
||||||
output_list.append(system_outputs[j])
|
output_list.append(system_outputs[j])
|
||||||
input_list = add_question(dimension=dim,
|
input_list = add_question(
|
||||||
output=output_list,
|
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
|
||||||
src=src_list,
|
)
|
||||||
context=context_list,
|
|
||||||
task=self.task)
|
|
||||||
sent_score = self.scorer.score(input_list, self.task, category, dim)
|
sent_score = self.scorer.score(input_list, self.task, category, dim)
|
||||||
|
|
||||||
# Get the summation score for each sample
|
# Get the summation score for each sample
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
score = []
|
score = []
|
||||||
for cur_n_sent in n_sents:
|
for cur_n_sent in n_sents:
|
||||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
|
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]))
|
||||||
start_idx += cur_n_sent
|
start_idx += cur_n_sent
|
||||||
|
|
||||||
# Calculate turn-level score for other dimensions
|
# Calculate turn-level score for other dimensions
|
||||||
elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
|
elif dim in ["naturalness", "coherence", "groundedness", "understandability"]:
|
||||||
src_list, output_list, context_list = [], [], []
|
src_list, output_list, context_list = [], [], []
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
src_list.append(data[i]['source'])
|
src_list.append(data[i]["source"])
|
||||||
output_list.append(data[i]['system_output'])
|
output_list.append(data[i]["system_output"])
|
||||||
context_list.append(data[i]['context'])
|
context_list.append(data[i]["context"])
|
||||||
input_list = add_question(dimension=dim,
|
input_list = add_question(
|
||||||
output=output_list,
|
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
|
||||||
src=src_list,
|
)
|
||||||
context=context_list,
|
|
||||||
task=self.task)
|
|
||||||
score = self.scorer.score(input_list, self.task, category, dim)
|
score = self.scorer.score(input_list, self.task, category, dim)
|
||||||
|
|
||||||
# Please customize other dimensions here for summarization
|
# Please customize other dimensions here for summarization
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('The input format for this dimension is still undefined. \
|
raise NotImplementedError(
|
||||||
Please customize it first.')
|
"The input format for this dimension is still undefined. \
|
||||||
|
Please customize it first."
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
eval_scores[i][dim] = score[i]
|
eval_scores[i][dim] = score[i]
|
||||||
@ -198,35 +198,35 @@ class DialogEvaluator:
|
|||||||
# Customize your overall score here.
|
# Customize your overall score here.
|
||||||
if overall == True:
|
if overall == True:
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||||
|
|
||||||
return eval_scores
|
return eval_scores
|
||||||
|
|
||||||
|
|
||||||
class D2tEvaluator:
|
class D2tEvaluator:
|
||||||
|
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
"""Set up evaluator for data-to-text"""
|
||||||
""" Set up evaluator for data-to-text """
|
|
||||||
self.scorer = UniEvaluator(
|
self.scorer = UniEvaluator(
|
||||||
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
|
model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
device=device,
|
device=device,
|
||||||
cache_dir=cache_dir)
|
cache_dir=cache_dir,
|
||||||
self.task = 'data2text'
|
)
|
||||||
self.dimensions = ['naturalness', 'informativeness']
|
self.task = "data2text"
|
||||||
|
self.dimensions = ["naturalness", "informativeness"]
|
||||||
|
|
||||||
def evaluate(self, data, category, dims=None, overall=True):
|
def evaluate(self, data, category, dims=None, overall=True):
|
||||||
"""
|
"""
|
||||||
Get the scores of all the given dimensions
|
Get the scores of all the given dimensions
|
||||||
|
|
||||||
category: The category to be evaluated.
|
category: The category to be evaluated.
|
||||||
|
|
||||||
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
|
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
|
||||||
two dimensions: naturalness and informativeness.
|
two dimensions: naturalness and informativeness.
|
||||||
|
|
||||||
overall: indicates whether the overall score is to be calculated.
|
overall: indicates whether the overall score is to be calculated.
|
||||||
Overall score can be customized to a combination of scores based on different
|
Overall score can be customized to a combination of scores based on different
|
||||||
dimensions. The default here is the average score of all the given dimensions.
|
dimensions. The default here is the average score of all the given dimensions.
|
||||||
"""
|
"""
|
||||||
n_data = len(data)
|
n_data = len(data)
|
||||||
eval_scores = [{} for _ in range(n_data)]
|
eval_scores = [{} for _ in range(n_data)]
|
||||||
@ -240,8 +240,8 @@ class D2tEvaluator:
|
|||||||
for dim in eval_dims:
|
for dim in eval_dims:
|
||||||
output_list, ref_list = [], []
|
output_list, ref_list = [], []
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
output_list.append(data[i]['system_output'])
|
output_list.append(data[i]["system_output"])
|
||||||
ref_list.append(data[i]['reference'])
|
ref_list.append(data[i]["reference"])
|
||||||
|
|
||||||
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
|
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
|
||||||
score = self.scorer.score(input_list, self.task, category, dim)
|
score = self.scorer.score(input_list, self.task, category, dim)
|
||||||
@ -252,38 +252,38 @@ class D2tEvaluator:
|
|||||||
# Customize your overall score here.
|
# Customize your overall score here.
|
||||||
if overall == True:
|
if overall == True:
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||||
|
|
||||||
return eval_scores
|
return eval_scores
|
||||||
|
|
||||||
|
|
||||||
class FactEvaluator:
|
class FactEvaluator:
|
||||||
|
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
"""Set up evaluator for factual consistency detection"""
|
||||||
""" Set up evaluator for factual consistency detection """
|
|
||||||
self.scorer = UniEvaluator(
|
self.scorer = UniEvaluator(
|
||||||
model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
|
model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
device=device,
|
device=device,
|
||||||
cache_dir=cache_dir)
|
cache_dir=cache_dir,
|
||||||
self.task = 'fact'
|
)
|
||||||
self.dim = 'consistency'
|
self.task = "fact"
|
||||||
|
self.dim = "consistency"
|
||||||
|
|
||||||
def evaluate(self, data, category):
|
def evaluate(self, data, category):
|
||||||
"""
|
"""
|
||||||
Get the factual consistency score (only 1 dimension for this task)
|
Get the factual consistency score (only 1 dimension for this task)
|
||||||
|
|
||||||
category: The category to be evaluated.
|
category: The category to be evaluated.
|
||||||
"""
|
"""
|
||||||
n_data = len(data)
|
n_data = len(data)
|
||||||
eval_scores = [{} for _ in range(n_data)]
|
eval_scores = [{} for _ in range(n_data)]
|
||||||
|
|
||||||
# Calculate average sentence-level scores for factual consistency
|
# Calculate average sentence-level scores for factual consistency
|
||||||
src_list, output_list = [], []
|
src_list, output_list = [], []
|
||||||
n_sents = [] # the number of sentences in the claim
|
n_sents = [] # the number of sentences in the claim
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
source = data[i]['source']
|
source = data[i]["source"]
|
||||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||||
n_sents.append(len(system_outputs))
|
n_sents.append(len(system_outputs))
|
||||||
for j in range(len(system_outputs)):
|
for j in range(len(system_outputs)):
|
||||||
src_list.append(source)
|
src_list.append(source)
|
||||||
@ -295,7 +295,7 @@ class FactEvaluator:
|
|||||||
start_idx = 0
|
start_idx = 0
|
||||||
score = []
|
score = []
|
||||||
for cur_n_sent in n_sents:
|
for cur_n_sent in n_sents:
|
||||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
|
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent)
|
||||||
start_idx += cur_n_sent
|
start_idx += cur_n_sent
|
||||||
|
|
||||||
for i in range(n_data):
|
for i in range(n_data):
|
||||||
@ -304,28 +304,26 @@ class FactEvaluator:
|
|||||||
return eval_scores
|
return eval_scores
|
||||||
|
|
||||||
|
|
||||||
def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
|
def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None):
|
||||||
assert task in ['summarization', 'dialogue', 'data2text', 'fact']
|
assert task in ["summarization", "dialogue", "data2text", "fact"]
|
||||||
if task == 'summarization':
|
if task == "summarization":
|
||||||
return SumEvaluator(model_name_or_path=model_name_or_path,
|
return SumEvaluator(
|
||||||
max_length=max_length,
|
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||||
device=device,
|
)
|
||||||
cache_dir=cache_dir)
|
elif task == "dialogue":
|
||||||
elif task == 'dialogue':
|
return DialogEvaluator(
|
||||||
return DialogEvaluator(model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||||
max_length=max_length,
|
)
|
||||||
device=device,
|
elif task == "data2text":
|
||||||
cache_dir=cache_dir)
|
return D2tEvaluator(
|
||||||
elif task == 'data2text':
|
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||||
return D2tEvaluator(model_name_or_path=model_name_or_path,
|
)
|
||||||
max_length=max_length,
|
elif task == "fact":
|
||||||
device=device,
|
return FactEvaluator(
|
||||||
cache_dir=cache_dir)
|
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||||
elif task == 'fact':
|
)
|
||||||
return FactEvaluator(model_name_or_path=model_name_or_path,
|
|
||||||
max_length=max_length,
|
|
||||||
device=device,
|
|
||||||
cache_dir=cache_dir)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Other tasks are not implemented, \
|
raise NotImplementedError(
|
||||||
please customize specific tasks here.')
|
"Other tasks are not implemented, \
|
||||||
|
please customize specific tasks here."
|
||||||
|
)
|
||||||
|
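
A hedged usage sketch for the factory above; it assumes a CUDA device, network access to download the MingZhong/unieval-* checkpoints, and data shaped by convert_data_to_unieval_format (also part of this commit):

# Hypothetical usage of get_evaluator; model weights download on first use.
evaluator = get_evaluator(task="summarization", device="cuda:0")
data = convert_data_to_unieval_format(
    output_list=["The cat sat on the mat."],
    src_list=["A cat was sitting on a mat in the kitchen."],
    ref_list=["A cat sat on the kitchen mat."],
)
scores = evaluator.evaluate(data, category="example", overall=True)
print(scores)  # one dict per sample: coherence/consistency/fluency/relevance/overall
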
@@ -27,9 +27,8 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer


 class UniEvaluator:
-
-    def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
-        """ Set up model """
+    def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
+        """Set up model"""
         self.device = device
         self.max_length = max_length
@@ -47,8 +46,8 @@ class UniEvaluator:

     def score(self, inputs, task, category, dim, batch_size=8):
         """
         Get scores for the given samples.
         final_score = postive_score / (postive_score + negative_score)
         """

         # The implementation of "forward" in T5 still requires decoder_input_ids.
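
The docstring's rule renormalizes the "Yes"/"No" probabilities taken from a softmax over the model's vocabulary. A toy numeric sketch (a two-token vocabulary is assumed here for brevity):

import torch

# Toy illustration of final_score = pos_score / (pos_score + neg_score).
# With only the Yes/No tokens the probabilities already sum to 1, so the
# result equals pos_score; in the real scorer the softmax runs over the
# full vocabulary, so the renormalization matters.
logits = torch.tensor([[2.0, 0.5]])  # [Yes, No] logits for one sample
probs = torch.softmax(logits, dim=-1)
pos_score, neg_score = probs[0, 0].item(), probs[0, 1].item()
print(pos_score / (pos_score + neg_score))
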
@@ -58,31 +57,27 @@ class UniEvaluator:

         pos_score_list, neg_score_list = [], []
         for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
-            src_list = inputs[i:i + batch_size]
-            tgt_list = tgts[i:i + batch_size]
+            src_list = inputs[i : i + batch_size]
+            tgt_list = tgts[i : i + batch_size]
             try:
                 with torch.no_grad():
-                    encoded_src = self.tokenizer(src_list,
-                                                 max_length=self.max_length,
-                                                 truncation=True,
-                                                 padding=True,
-                                                 return_tensors='pt')
-                    encoded_tgt = self.tokenizer(tgt_list,
-                                                 max_length=self.max_length,
-                                                 truncation=True,
-                                                 padding=True,
-                                                 return_tensors='pt')
+                    encoded_src = self.tokenizer(
+                        src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
+                    )
+                    encoded_tgt = self.tokenizer(
+                        tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
+                    )

-                    src_tokens = encoded_src['input_ids'].to(self.device)
-                    src_mask = encoded_src['attention_mask'].to(self.device)
+                    src_tokens = encoded_src["input_ids"].to(self.device)
+                    src_mask = encoded_src["attention_mask"].to(self.device)

-                    tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
+                    tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1)

                     output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
                     logits = output.logits.view(-1, self.model.config.vocab_size)

                     pos_score = self.softmax(logits)[:, self.pos_id]  # Yes
                     neg_score = self.softmax(logits)[:, self.neg_id]  # No

                     cur_pos_score = [x.item() for x in pos_score]
                     cur_neg_score = [x.item() for x in neg_score]
@@ -90,8 +85,8 @@ class UniEvaluator:
                 neg_score_list += cur_neg_score

             except RuntimeError:
-                print(f'source: {src_list}')
-                print(f'target: {tgt_list}')
+                print(f"source: {src_list}")
+                print(f"target: {tgt_list}")
                 exit(0)

         score_list = []
@@ -31,105 +31,142 @@ import tqdm

 def add_question(dimension, output, src=None, ref=None, context=None, task=None):
     """
     Add questions to generate input in Bool-QA format for UniEval.

     dimension: specific dimension to be evaluated
     src: source input for different NLG tasks. For example, source document for summarization
          and dialogue history for dialogue response generation.
     output: output text generated by the models
     ref: human-annotated groundtruth
     context: the context needed to evaluate several specific dimension. For example,
              additional factual information when evaluating engagingness and groundedness in dialogues.
     """

     input_with_question = []
     for i in range(len(output)):
         # For summarization
-        if task == 'summarization':
-            if dimension == 'fluency':
-                cur_input = 'question: Is this a fluent paragraph? </s> paragraph: ' + output[i]
-            elif dimension == 'coherence':
-                cur_input = 'question: Is this a coherent summary to the document? </s> summary: ' + output[
-                    i] + ' </s> document: ' + src[i]
-            elif dimension == 'consistency':
-                cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
-                    i] + ' </s> document: ' + src[i]
-            elif dimension == 'relevance':
-                cur_input = 'question: Is this summary relevant to the reference? </s> summary: ' + output[
-                    i] + ' </s> reference: ' + ref[i]
+        if task == "summarization":
+            if dimension == "fluency":
+                cur_input = "question: Is this a fluent paragraph? </s> paragraph: " + output[i]
+            elif dimension == "coherence":
+                cur_input = (
+                    "question: Is this a coherent summary to the document? </s> summary: "
+                    + output[i]
+                    + " </s> document: "
+                    + src[i]
+                )
+            elif dimension == "consistency":
+                cur_input = (
+                    "question: Is this claim consistent with the document? </s> claim: "
+                    + output[i]
+                    + " </s> document: "
+                    + src[i]
+                )
+            elif dimension == "relevance":
+                cur_input = (
+                    "question: Is this summary relevant to the reference? </s> summary: "
+                    + output[i]
+                    + " </s> reference: "
+                    + ref[i]
+                )
             else:
                 raise NotImplementedError(
-                    'The input format for this dimension is still undefined. Please customize it first.')
+                    "The input format for this dimension is still undefined. Please customize it first."
+                )
         # For dialogues
-        elif task == 'dialogue':
-            if dimension == 'naturalness':
-                cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + output[i]
-            elif dimension == 'coherence':
-                cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: '\
-                    + output[i] + ' </s> dialogue history: ' + src[i]
-            elif dimension == 'engagingness':
-                cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: '\
-                    + output[i] + ' </s> dialogue history: ' + src[i] + ' </s> fact: ' + context[i]
-            elif dimension == 'groundedness':
-                cur_input = 'question: Is this response consistent with knowledge in the fact? </s> response: '\
-                    + output[i] + ' </s> fact: ' + context[i]
-            elif dimension == 'understandability':
-                cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + output[i]
+        elif task == "dialogue":
+            if dimension == "naturalness":
+                cur_input = "question: Is this a natural response in the dialogue? </s> response: " + output[i]
+            elif dimension == "coherence":
+                cur_input = (
+                    "question: Is this a coherent response given the dialogue history? </s> response: "
+                    + output[i]
+                    + " </s> dialogue history: "
+                    + src[i]
+                )
+            elif dimension == "engagingness":
+                cur_input = (
+                    "question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: "
+                    + output[i]
+                    + " </s> dialogue history: "
+                    + src[i]
+                    + " </s> fact: "
+                    + context[i]
+                )
+            elif dimension == "groundedness":
+                cur_input = (
+                    "question: Is this response consistent with knowledge in the fact? </s> response: "
+                    + output[i]
+                    + " </s> fact: "
+                    + context[i]
+                )
+            elif dimension == "understandability":
+                cur_input = "question: Is this an understandable response in the dialogue? </s> response: " + output[i]
             else:
                 raise NotImplementedError(
-                    'The input format for this dimension is still undefined. Please customize it first.')
+                    "The input format for this dimension is still undefined. Please customize it first."
+                )
         # For data-to-text
-        elif task == 'data2text':
-            if dimension == 'naturalness':
-                cur_input = 'question: Is this a fluent utterance? </s> utterance: ' + output[i]
-            elif dimension == 'informativeness':
-                cur_input = 'question: Is this sentence informative according to the reference? </s> sentence: '\
-                    + output[i] + ' </s> reference: ' + ref[i]
+        elif task == "data2text":
+            if dimension == "naturalness":
+                cur_input = "question: Is this a fluent utterance? </s> utterance: " + output[i]
+            elif dimension == "informativeness":
+                cur_input = (
+                    "question: Is this sentence informative according to the reference? </s> sentence: "
+                    + output[i]
+                    + " </s> reference: "
+                    + ref[i]
+                )
             else:
                 raise NotImplementedError(
-                    'The input format for this dimension is still undefined. Please customize it first.')
+                    "The input format for this dimension is still undefined. Please customize it first."
+                )
         # For factual consistency detection
-        elif task == 'fact':
-            if dimension == 'consistency':
-                cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
-                    i] + ' </s> document: ' + src[i]
+        elif task == "fact":
+            if dimension == "consistency":
+                cur_input = (
+                    "question: Is this claim consistent with the document? </s> claim: "
+                    + output[i]
+                    + " </s> document: "
+                    + src[i]
+                )
             else:
-                raise NotImplementedError('No other dimensions for the factual consistency detection task.')
+                raise NotImplementedError("No other dimensions for the factual consistency detection task.")
         # For new customized tasks
         else:
-            raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
+            raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")
         input_with_question.append(cur_input)
     return input_with_question


 def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
     """
     Convert the data into the unieval's format.

     output_list: a list of model output

     src_list: source input for different NLG tasks. For example, source document for summarization
               and dialogue history for dialogue response generation
     ref_list: human-annotated groundtruth
     """
     json_data = []
     for i in range(len(output_list)):
         cur = {}
-        cur['system_output'] = output_list[i]
+        cur["system_output"] = output_list[i]
         if src_list is not None:
-            cur['source'] = src_list[i]
+            cur["source"] = src_list[i]
         if ref_list is not None:
-            cur['reference'] = ref_list[i]
-        cur['context'] = ""
+            cur["reference"] = ref_list[i]
+        cur["context"] = ""
         json_data.append(cur)
     return json_data


 def calculate_average_score(scores):
     """
     Calculate average scores for different metrics

     scores: a list of scores for different metrics for each answer

     """
     metrics = {metric: 0 for metric in scores[0]}
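
To make the Bool-QA format concrete, a small invocation of the helper above (the input string is hypothetical):

# Hypothetical call showing the Bool-QA input string the helper builds.
inputs = add_question(
    dimension="fluency",
    output=["The cat sat on the mat."],
    task="summarization",
)
print(inputs[0])
# question: Is this a fluent paragraph? </s> paragraph: The cat sat on the mat.
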
@@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None:
     frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))

     for metric in tqdm.tqdm(
         frame_per_metric.keys(),
         desc=f"UniEval metrics: ",
         total=len(frame_per_metric.keys()),
     ):
         data = pd.DataFrame(frame_per_metric[metric])
@@ -1,7 +1,6 @@
 import io
 import json
 import os
-import re
 import string
 from typing import Dict

@@ -55,7 +54,7 @@ def jload(f, mode="r"):


 def get_json_list(file_path):
-    with open(file_path, 'r') as f:
+    with open(file_path, "r") as f:
         json_list = []
         for line in f:
             json_list.append(json.loads(line))
@@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
     frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))

     for metric in tqdm.tqdm(
         frame_per_metric.keys(),
         desc=f"automatic metrics: ",
         total=len(frame_per_metric.keys()),
     ):
         data = pd.DataFrame(frame_per_metric[metric])
|
@ -3,7 +3,6 @@ import json
|
|||||||
from typing import Dict, Sequence
|
from typing import Dict, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
             padding="longest",
             max_length=max_length,
             truncation=True,
-        ) for text in strings
+        )
+        for text in strings
     ]
     input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
     input_ids_lens = labels_lens = [
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo


 class EasySupervisedDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
         super(EasySupervisedDataset, self).__init__()
         with open(data_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
-        #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
+        # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
         sources, targets = [], []
         for line in all_lines:
             if "回答:" in line:
                 sep_index = line.index("回答:")
-                sources.append(line[:sep_index + 3])
-                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+                sources.append(line[: sep_index + 3])
+                targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
             else:
                 sources.append(line)
                 targets.append("" + tokenizer.eos_token)
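
The split keeps everything up to and including the three characters of "回答:" as the source, and the remainder as the target. A quick sketch with hypothetical text (eos token omitted):

# Sketch of the source/target split above: "回答:" is 3 characters long.
line = "提问:什么是预训练? 回答:预训练是一种训练方式。"
sep_index = line.index("回答:")
source, target = line[: sep_index + 3], line[sep_index + 3 :]
print(source)  # 提问:什么是预训练? 回答:
print(target)  # 预训练是一种训练方式。
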
@@ -83,15 +82,17 @@ class EasySupervisedDataset:


 class EasyPromptsDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
         super(EasyPromptsDataset, self).__init__()
         with open(data_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
-        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
+        all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
         self.prompts = [
-            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
-                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
+            tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
+                "input_ids"
+            ]
+            .to(torch.cuda.current_device())
+            .squeeze(0)
             for line in tqdm(all_lines)
         ]
         self.data_file = data_file
@@ -110,7 +111,6 @@ class EasyPromptsDataset:


 class EasyRewardDataset(Dataset):
-
     def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
         super(EasyRewardDataset, self).__init__()
         self.chosen = []
@@ -120,44 +120,42 @@ class EasyRewardDataset:
         else:
             self.end_token = special_token
             print(self.end_token)
-        #read all lines in the train_file to a list
+        # read all lines in the train_file to a list
         with open(train_file, "r", encoding="UTF-8") as f:
             all_lines = f.readlines()
             for line in tqdm(all_lines):
                 data = json.loads(line)
-                prompt = "提问:" + data['prompt'] + " 回答:"
+                prompt = "提问:" + data["prompt"] + " 回答:"

-                chosen = prompt + data['chosen'] + self.end_token
-                chosen_token = tokenizer(chosen,
-                                         max_length=max_length,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                self.chosen.append({
-                    "input_ids": chosen_token['input_ids'],
-                    "attention_mask": chosen_token['attention_mask']
-                })
+                chosen = prompt + data["chosen"] + self.end_token
+                chosen_token = tokenizer(
+                    chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+                )
+                self.chosen.append(
+                    {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+                )

-                reject = prompt + data['rejected'] + self.end_token
-                reject_token = tokenizer(reject,
-                                         max_length=max_length,
-                                         padding="max_length",
-                                         truncation=True,
-                                         return_tensors="pt")
-                self.reject.append({
-                    "input_ids": reject_token['input_ids'],
-                    "attention_mask": reject_token['attention_mask']
-                })
+                reject = prompt + data["rejected"] + self.end_token
+                reject_token = tokenizer(
+                    reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+                )
+                self.reject.append(
+                    {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
+                )

     def __len__(self):
         length = len(self.chosen)
         return length

     def __getitem__(self, idx):
-        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
-            "input_ids"], self.reject[idx]["attention_mask"]
+        return (
+            self.chosen[idx]["input_ids"],
+            self.chosen[idx]["attention_mask"],
+            self.reject[idx]["input_ids"],
+            self.reject[idx]["attention_mask"],
+        )

-    #python representation of the object and the string representation of the object
+    # python representation of the object and the string representation of the object
     def __repr__(self):
         return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
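
The loop above implies one JSON object per line with prompt/chosen/rejected keys. A hedged sketch of a matching train_file line (the content is hypothetical):

import json

# One line of the implied reward-data format.
line = json.dumps(
    {"prompt": "什么是强化学习?", "chosen": "强化学习是一种通过奖励学习策略的方法。", "rejected": "不知道。"},
    ensure_ascii=False,
)
data = json.loads(line)
prompt = "提问:" + data["prompt"] + " 回答:"
print(prompt + data["chosen"])  # the text tokenized as the "chosen" sequence
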
@@ -165,26 +163,25 @@ class EasyRewardDataset:
         return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


-'''
+"""
 Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better.
 If individual lines are not related, just set is_group_texts to False.
-'''
+"""


 class EasySFTDataset(Dataset):
-
     def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
         super().__init__()
-        #read the data_file line by line
+        # read the data_file line by line
         with open(data_file, "r", encoding="UTF-8") as f:
-            #encode the text data line by line and put raw python list input_ids only to raw_input_ids list
+            # encode the text data line by line and put raw python list input_ids only to raw_input_ids list
             raw_input_ids = []
             for line in f:
                 encoded_ids = tokenizer.encode(line)
-                #if the encoded_ids is longer than max_length, then split it into several parts
+                # if the encoded_ids is longer than max_length, then split it into several parts
                 if len(encoded_ids) > max_length:
                     for i in range(0, len(encoded_ids), max_length):
-                        raw_input_ids.append(encoded_ids[i:i + max_length])
+                        raw_input_ids.append(encoded_ids[i : i + max_length])
                 else:
                     raw_input_ids.append(encoded_ids)
@@ -196,12 +193,13 @@ class EasySFTDataset:
         if is_group_texts:
             for input_ids in raw_input_ids:
                 if len(current_input_ids) + len(input_ids) > max_length:
-                    #pad the current_input_ids to max_length with tokenizer.pad_token_id
+                    # pad the current_input_ids to max_length with tokenizer.pad_token_id
                     padded_length = max_length - len(current_input_ids)
                     current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                     grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                     attention_mask.append(
-                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                    )
                     current_input_ids = []
                 else:
                     current_input_ids.extend(input_ids)
@@ -210,14 +208,16 @@ class EasySFTDataset:
                 current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                 grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                 attention_mask.append(
-                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                )
         else:
-            #just append the raw_input_ids to max_length
+            # just append the raw_input_ids to max_length
             for input_ids in raw_input_ids:
                 padded_length = max_length - len(input_ids)
                 input_ids.extend([tokenizer.pad_token_id] * padded_length)
                 attention_mask.append(
-                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+                )
                 grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
         self.input_ids = grouped_input_ids
         self.labels = copy.deepcopy(self.input_ids)
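
A minimal sketch of the grouping-and-padding behavior in the is_group_texts branch, on plain lists (pad id 0 assumed). Note that, as in the branch above, a line that would overflow the current group is itself skipped rather than starting a fresh group:

# Greedy grouping of tokenized lines up to max_length, then right-padding,
# mirroring the is_group_texts branch above (pad_token_id assumed to be 0).
def group_and_pad(lines, max_length=8, pad_id=0):
    groups, current = [], []
    for ids in lines:
        if len(current) + len(ids) > max_length:
            groups.append(current + [pad_id] * (max_length - len(current)))
            current = []  # the overflowing line itself is dropped, as above
        else:
            current.extend(ids)
    if current:  # flush the tail, as the code after the loop does
        groups.append(current + [pad_id] * (max_length - len(current)))
    return groups

print(group_and_pad([[1, 2, 3], [4, 5], [6, 7, 8, 9]]))  # [[1, 2, 3, 4, 5, 0, 0, 0]]
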
@@ -227,14 +227,14 @@ class EasySFTDataset:
     def __len__(self):
         return len(self.input_ids)

-    #get item from dataset
+    # get item from dataset
     def __getitem__(self, idx):
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

-    #generate the dataset description to be printed by print in python
+    # generate the dataset description to be printed by print in python
     def __repr__(self):
         return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

-    #generate the dataset description to be printed by print in python
+    # generate the dataset description to be printed by print in python
     def __str__(self):
         return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from coati.models.generation import generate
-from coati.models.utils import log_probs_from_logits, masked_mean
+from coati.models.utils import log_probs_from_logits
 from peft import PeftModel
 from torch.nn.modules import Module
 from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ class Actor(Module):

     @torch.no_grad()
     def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
+        self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
     ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
         sequences = generate(self.model, input_ids, **kwargs)
         attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
+        pad_token_id = kwargs.get("pad_token_id", None)
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
             return sequences, attention_mask, None
         input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
+        eos_token_id = kwargs.get("eos_token_id", None)
         if eos_token_id is None:
             action_mask = torch.ones_like(sequences, dtype=torch.bool)
         else:
             # left padding may be applied, only mask action
             action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
             action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
         action_mask[:, :input_len] = False
         action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]

-    def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
-        """
+    def forward(
+        self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Returns action log probs"""
         output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
+        logits = output["logits"]
         log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
         return log_probs[:, -num_actions:]

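The action-mask construction in `Actor.generate` above hinges on a cumsum trick: `(tokens == eos).cumsum(dim=-1) == 0` is True strictly before the first EOS, and the subsequent `F.pad(..., value=True)` shifts the mask right by one so the EOS token itself stays unmasked. A standalone illustration with made-up token ids:

    import torch
    import torch.nn.functional as F

    eos_token_id = 2
    generated = torch.tensor([[5, 7, 2, 0, 0]])  # tokens after the prompt; 0 is padding
    mask = (generated == eos_token_id).cumsum(dim=-1) == 0
    print(mask)                                 # [[True, True, False, False, False]]
    shifted = F.pad(mask, (1, -1), value=True)  # shift right by one; the EOS position is now kept
    print(shifted)                              # [[True, True, True, False, False]]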
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
         lora_train_bias (str): LoRA bias training mode.
     """

-    def __init__(self,
-                 pretrained: str = None,
-                 config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
-                 lora_path: str = None) -> None:
+    def __init__(
+        self,
+        pretrained: str = None,
+        config: Optional[BloomConfig] = None,
+        checkpoint: bool = False,
+        lora_path: str = None,
+    ) -> None:
         if pretrained is not None:
             model = BloomForCausalLM.from_pretrained(pretrained)
         elif config is not None:
@@ -1,18 +1,16 @@
 import argparse

-import pandas as pd
 import torch
 import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import DataCollatorForSupervisedDataset
 from coati.models.bloom import BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.gpt import GPTRM, GPTCritic
+from coati.models.llama import LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTCritic
 from coati.trainer import PPOTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
 from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
 from easy_models import BLOOMActor
-from peft import PeftModel
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam

 def main(args):
     # configure strategy
-    if args.strategy == 'ddp':
+    if args.strategy == "ddp":
         strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
-    elif args.strategy == 'colossalai_zero2':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
+    elif args.strategy == "colossalai_gemini":
+        strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
+    elif args.strategy == "colossalai_zero2":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     if args.rm_path is not None:
-        state_dict = torch.load(args.rm_path, map_location='cpu')
+        state_dict = torch.load(args.rm_path, map_location="cpu")

     # configure model
-    if args.model == 'bloom':
+    if args.model == "bloom":
         # initial_model = BLOOMActor(pretrained=args.pretrain)
-        print('Using peft lora to load Bloom model as initial_model')
+        print("Using peft lora to load Bloom model as initial_model")
         initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
-        print('Using peft lora to load Bloom model as initial_model (Done)')
+        print("Using peft lora to load Bloom model as initial_model (Done)")
     else:
         raise ValueError(f'Unsupported actor model "{args.model}"')

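The strategy selection above is a straight if/elif chain; a table-driven equivalent (a sketch only, using the same coati strategy classes and arguments shown in the hunk) keeps the mapping in one place:

    from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy

    # Name-to-factory table mirroring the dispatch in main() above.
    STRATEGY_FACTORIES = {
        "ddp": lambda: DDPStrategy(),
        "colossalai_gemini": lambda: GeminiStrategy(placement_policy="cpu", initial_scale=2**5),
        "colossalai_zero2": lambda: LowLevelZeroStrategy(stage=2, placement_policy="cpu"),
    }

    def make_strategy(name):
        if name not in STRATEGY_FACTORIES:
            raise ValueError(f'Unsupported strategy "{name}"')
        return STRATEGY_FACTORIES[name]()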
@@ -49,59 +47,59 @@ def main(args):
     else:
         rm_model_name = args.rm_model

-    if rm_model_name == 'gpt2':
+    if rm_model_name == "gpt2":
         reward_model = GPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'bloom':
+    elif rm_model_name == "bloom":
         print("load bloom reward model ", args.rm_pretrain)
         reward_model = BLOOMRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'opt':
+    elif rm_model_name == "opt":
         reward_model = OPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'llama':
+    elif rm_model_name == "llama":
         reward_model = LlamaRM(pretrained=args.rm_pretrain)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')

     if args.rm_path is not None:
-        print('Loading reward model from', args.rm_path)
+        print("Loading reward model from", args.rm_path)
         reward_model.load_state_dict(state_dict)

-    if args.strategy != 'colossalai_gemini':
+    if args.strategy != "colossalai_gemini":
         initial_model.to(torch.float16).to(torch.cuda.current_device())
         reward_model.to(torch.float16).to(torch.cuda.current_device())

     with strategy.model_init_context():
-        if args.model == 'bloom':
+        if args.model == "bloom":
             # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
-            print('Using peft lora to load Bloom model as Actor')
+            print("Using peft lora to load Bloom model as Actor")
             actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
-            print('Using peft lora to load Bloom model as Actor (Done)')
+            print("Using peft lora to load Bloom model as Actor (Done)")
         else:
             raise ValueError(f'Unsupported actor model "{args.model}"')

-        if rm_model_name == 'gpt2':
+        if rm_model_name == "gpt2":
             critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
-        elif rm_model_name == 'bloom':
+        elif rm_model_name == "bloom":
             print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
             critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
             print("load bloom critic (Done) ")
-        elif rm_model_name == 'opt':
+        elif rm_model_name == "opt":
             critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
-        elif rm_model_name == 'llama':
+        elif rm_model_name == "llama":
             critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
         else:
             raise ValueError(f'Unsupported reward model "{rm_model_name}"')

         if args.rm_path is not None:
-            print('Loading reward model from', args.rm_path)
+            print("Loading reward model from", args.rm_path)
             critic.load_state_dict(state_dict)
             del state_dict

-    if args.strategy != 'colossalai_gemini':
+    if args.strategy != "colossalai_gemini":
         critic.to(torch.float16).to(torch.cuda.current_device())
         actor.to(torch.float16).to(torch.cuda.current_device())

     # configure optimizer
-    if args.strategy.startswith('colossalai'):
+    if args.strategy.startswith("colossalai"):
         actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
         critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
     else:
@@ -109,18 +107,18 @@ def main(args):
         critic_optim = Adam(critic.parameters(), lr=1e-7)

     # configure tokenizer
-    if args.model == 'gpt2':
+    if args.model == "gpt2":
         tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'bloom':
+    elif args.model == "bloom":
         tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'opt':
+    elif args.model == "opt":
         tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'llama':
+    elif args.model == "llama":
         tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
-        tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "<\s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -132,26 +130,27 @@ def main(args):
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
     else:
         prompt_sampler = None
-    prompt_dataloader = DataLoader(prompt_dataset,
-                                   shuffle=(prompt_sampler is None),
-                                   sampler=prompt_sampler,
-                                   batch_size=args.train_batch_size)
+    prompt_dataloader = DataLoader(
+        prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
+    )

     pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
     else:
         pretrain_sampler = None
-    pretrain_dataloader = DataLoader(pretrain_dataset,
-                                     shuffle=(pretrain_sampler is None),
-                                     sampler=pretrain_sampler,
-                                     batch_size=args.ptx_batch_size,
-                                     collate_fn=data_collator)
+    pretrain_dataloader = DataLoader(
+        pretrain_dataset,
+        shuffle=(pretrain_sampler is None),
+        sampler=pretrain_sampler,
+        batch_size=args.ptx_batch_size,
+        collate_fn=data_collator,
+    )

     def tokenize_fn(texts):
         # MUST padding to max length to ensure inputs of all ranks have the same length
         # Different length may lead to hang when using gemini, as different generation steps
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+        batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
         return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
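The comment in `tokenize_fn` above is the operative constraint: under Gemini, ranks whose inputs differ in length can take different numbers of generation steps and hang on the next collective, so every rank pads to the same fixed `max_length`. The shape invariant is easy to check in isolation (a sketch; any HuggingFace tokenizer will do, the checkpoint below is just an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # example checkpoint
    batch = tokenizer(
        ["a short prompt", "a noticeably longer prompt than the first one"],
        return_tensors="pt", max_length=96, padding="max_length", truncation=True,
    )
    assert batch["input_ids"].shape[-1] == 96  # identical shape on every rank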
@@ -178,45 +177,46 @@ def main(args):
         eos_token_id=tokenizer.eos_token_id,
     )

-    trainer.fit(prompt_dataloader=prompt_dataloader,
-                pretrain_dataloader=pretrain_dataloader,
-                num_episodes=args.num_episodes,
-                num_update_steps=args.num_update_steps,
-                num_collect_steps=args.num_collect_steps)
+    trainer.fit(
+        prompt_dataloader=prompt_dataloader,
+        pretrain_dataloader=pretrain_dataloader,
+        num_episodes=args.num_episodes,
+        num_update_steps=args.num_update_steps,
+        num_collect_steps=args.num_collect_steps,
+    )

     # save model checkpoint after fitting
     trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
-        strategy.save_optimizer(actor_optim,
-                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                                only_rank0=False)
+        strategy.save_optimizer(
+            actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
-    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
-    parser.add_argument('--strategy',
-                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='ddp',
-                        help='strategy to use')
-    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--sft_lora_path', type=str, default=None)
-    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
-    parser.add_argument('--rm_path', type=str, default=None)
-    parser.add_argument('--rm_pretrain', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--num_collect_steps', type=int, default=10)
-    parser.add_argument('--num_update_steps', type=int, default=5)
-    parser.add_argument('--train_batch_size', type=int, default=2)
-    parser.add_argument('--ptx_batch_size', type=int, default=1)
-    parser.add_argument('--experience_batch_size', type=int, default=8)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument('--kl_coef', type=float, default=0.1)
-    parser.add_argument('--ptx_coef', type=float, default=0.9)
+    parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
+    parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+    parser.add_argument(
+        "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
+    )
+    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--sft_lora_path", type=str, default=None)
+    parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+    parser.add_argument("--rm_path", type=str, default=None)
+    parser.add_argument("--rm_pretrain", type=str, default=None)
+    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+    parser.add_argument("--num_episodes", type=int, default=10)
+    parser.add_argument("--num_collect_steps", type=int, default=10)
+    parser.add_argument("--num_update_steps", type=int, default=5)
+    parser.add_argument("--train_batch_size", type=int, default=2)
+    parser.add_argument("--ptx_batch_size", type=int, default=1)
+    parser.add_argument("--experience_batch_size", type=int, default=8)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--kl_coef", type=float, default=0.1)
+    parser.add_argument("--ptx_coef", type=float, default=0.9)

     args = parser.parse_args()
     main(args)
@@ -1,18 +1,10 @@
 import argparse
 import os

-import loralib as lora
 import torch
 import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
 from coati.trainer import SFTTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
-from datasets import load_dataset
 from easy_dataset import EasyDataset
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 from torch.optim import Adam
@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter

 def train(args):
     # configure strategy
-    if args.strategy == 'ddp':
+    if args.strategy == "ddp":
         strategy = DDPStrategy()
-    elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda')
-    elif args.strategy == 'colossalai_zero2':
-        strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+    elif args.strategy == "colossalai_gemini":
+        strategy = GeminiStrategy(placement_policy="cuda")
+    elif args.strategy == "colossalai_zero2":
+        strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     # configure model
     with strategy.model_init_context():
-        print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
+        print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
         model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
         # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
-        if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
-                and os.path.exists(args.save_path + '/adapter_model.bin'):
+        if (
+            os.path.exists(args.save_path)
+            and os.path.exists(args.save_path + "/adapter_config.json")
+            and os.path.exists(args.save_path + "/adapter_model.bin")
+        ):
             print("loading from saved peft model ", args.save_path)
             model = PeftModel.from_pretrained(model, args.save_path)
         else:
             # we'll use peft lora library to do the lora
             lora_rank = args.lora_rank if args.lora_rank > 0 else 32
             # config lora with rank of lora_rank
-            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
-                                     inference_mode=False,
-                                     r=lora_rank,
-                                     lora_alpha=32,
-                                     lora_dropout=0.1)
+            lora_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
+            )
             model = get_peft_model(model, lora_config)
             model.print_trainable_parameters()

     # configure tokenizer
-    if args.model == 'gpt2':
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    if args.model == "gpt2":
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'bloom':
+    elif args.model == "bloom":
         tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'opt':
+    elif args.model == "opt":
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
         tokenizer.pad_token = tokenizer.eos_token
-    elif args.model == 'llama':
+    elif args.model == "llama":
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrain,
             padding_side="right",
             use_fast=False,
         )
-        tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "<\s>"
         tokenizer.pad_token = tokenizer.unk_token
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

-    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
+    if args.model == "llama" and args.strategy == "colossalai_gemini":
         # this is a hack to deal with the resized embedding
         # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
         for name, param in model.named_parameters():
             if not isinstance(param, ColoParameter):
-                sub_module_name = '.'.join(name.split('.')[:-1])
-                weight_name = name.split('.')[-1]
+                sub_module_name = ".".join(name.split(".")[:-1])
+                weight_name = name.split(".")[-1]
                 sub_module = model.get_submodule(sub_module_name)
                 setattr(sub_module, weight_name, ColoParameter(param))

     # configure optimizer
-    if args.strategy.startswith('colossalai'):
+    if args.strategy.startswith("colossalai"):
         optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
     else:
         optim = Adam(model.parameters(), lr=args.lr)

     logger = get_dist_logger()
-    logger.set_level('WARNING')
+    logger.set_level("WARNING")

     # configure dataset
     law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
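The adapter handling in this hunk implements a load-or-create pattern: reuse the PEFT adapter saved under `args.save_path` when both `adapter_config.json` and `adapter_model.bin` are present, otherwise attach a fresh LoRA adapter. Distilled into a helper (a sketch; `load_or_create_lora` is an illustrative name, with the same LoraConfig hyperparameters as the patch):

    import os
    from peft import LoraConfig, PeftModel, TaskType, get_peft_model

    def load_or_create_lora(model, save_path, lora_rank=32):
        # Resume from a previously saved adapter when both artifacts exist...
        if os.path.exists(os.path.join(save_path, "adapter_config.json")) and os.path.exists(
            os.path.join(save_path, "adapter_model.bin")
        ):
            return PeftModel.from_pretrained(model, save_path)
        # ...otherwise wrap the base model with a new LoRA adapter.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
        )
        return get_peft_model(model, lora_config)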
@@ -108,47 +101,57 @@ def train(args):
         eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
     data_collator = default_collate
     if dist.is_initialized() and dist.get_world_size() > 1:
-        train_sampler = DistributedSampler(train_dataset,
-                                           shuffle=True,
-                                           seed=42,
-                                           drop_last=True,
-                                           rank=dist.get_rank(),
-                                           num_replicas=dist.get_world_size())
+        train_sampler = DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            seed=42,
+            drop_last=True,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+        )
         if eval_dataset is not None:
-            eval_sampler = DistributedSampler(eval_dataset,
-                                              shuffle=False,
-                                              seed=42,
-                                              drop_last=False,
-                                              rank=dist.get_rank(),
-                                              num_replicas=dist.get_world_size())
+            eval_sampler = DistributedSampler(
+                eval_dataset,
+                shuffle=False,
+                seed=42,
+                drop_last=False,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+            )
     else:
         train_sampler = None
         eval_sampler = None

-    train_dataloader = DataLoader(train_dataset,
-                                  shuffle=(train_sampler is None),
-                                  sampler=train_sampler,
-                                  batch_size=args.batch_size,
-                                  collate_fn=data_collator,
-                                  pin_memory=True)
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=(train_sampler is None),
+        sampler=train_sampler,
+        batch_size=args.batch_size,
+        collate_fn=data_collator,
+        pin_memory=True,
+    )
     if eval_dataset is not None:
-        eval_dataloader = DataLoader(eval_dataset,
-                                     shuffle=(eval_sampler is None),
-                                     sampler=eval_sampler,
-                                     batch_size=args.batch_size,
-                                     collate_fn=data_collator,
-                                     pin_memory=True)
+        eval_dataloader = DataLoader(
+            eval_dataset,
+            shuffle=(eval_sampler is None),
+            sampler=eval_sampler,
+            batch_size=args.batch_size,
+            collate_fn=data_collator,
+            pin_memory=True,
+        )
     else:
         eval_dataloader = None

-    trainer = SFTTrainer(model=model,
-                         strategy=strategy,
-                         optim=optim,
-                         train_dataloader=train_dataloader,
-                         eval_dataloader=eval_dataloader,
-                         batch_size=args.batch_size,
-                         max_epochs=args.max_epochs,
-                         accumulation_steps=args.accumulation_steps)
+    trainer = SFTTrainer(
+        model=model,
+        strategy=strategy,
+        optim=optim,
+        train_dataloader=train_dataloader,
+        eval_dataloader=eval_dataloader,
+        batch_size=args.batch_size,
+        max_epochs=args.max_epochs,
+        accumulation_steps=args.accumulation_steps,
+    )

     trainer.fit(logger=logger, log_interval=args.log_interval)

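Every dataloader in this hunk repeats the same sampler-aware pattern: under multi-rank training a `DistributedSampler` owns the shuffling, so the `DataLoader` must not shuffle again whenever a sampler is passed. Condensed into one helper (a sketch; `make_loader` is an illustrative name):

    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def make_loader(dataset, batch_size, shuffle, collate_fn=None):
        sampler = None
        if dist.is_initialized() and dist.get_world_size() > 1:
            # the sampler shuffles per rank; the DataLoader must not shuffle again
            sampler = DistributedSampler(dataset, shuffle=shuffle, seed=42, drop_last=shuffle)
        return DataLoader(
            dataset,
            shuffle=(sampler is None and shuffle),
            sampler=sampler,
            batch_size=batch_size,
            collate_fn=collate_fn,
            pin_memory=True,
        )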
@@ -156,29 +159,27 @@ def train(args):
     trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
-        strategy.save_optimizer(trainer.optimizer,
-                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
-                                only_rank0=False)
+        strategy.save_optimizer(
+            trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+        )


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--strategy',
-                        choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
-                        default='ddp')
-    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
-    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--dataset', type=str, default=None)
-    parser.add_argument('--eval_dataset', type=str, default=None)
-    parser.add_argument('--save_path', type=str, default='output')
-    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
-    parser.add_argument('--max_epochs', type=int, default=3)
-    parser.add_argument('--batch_size', type=int, default=4)
-    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
-    parser.add_argument('--lr', type=float, default=5e-6)
-    parser.add_argument('--accumulation_steps', type=int, default=8)
-    parser.add_argument('--enable_peft_lora', action='store_true', default=False)
-    parser.add_argument("--is_short_text", action='store_true', default=False)
+    parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+    parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+    parser.add_argument("--pretrain", type=str, default=None)
+    parser.add_argument("--dataset", type=str, default=None)
+    parser.add_argument("--eval_dataset", type=str, default=None)
+    parser.add_argument("--save_path", type=str, default="output")
+    parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+    parser.add_argument("--max_epochs", type=int, default=3)
+    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+    parser.add_argument("--lr", type=float, default=5e-6)
+    parser.add_argument("--accumulation_steps", type=int, default=8)
+    parser.add_argument("--enable_peft_lora", action="store_true", default=False)
+    parser.add_argument("--is_short_text", action="store_true", default=False)

     args = parser.parse_args()
     train(args)
Some files were not shown because too many files have changed in this diff.